fix: handle expect header

This commit is contained in:
Yun Xu 2019-06-03 22:08:24 -07:00
parent c15158224b
commit f21db60859
4 changed files with 72 additions and 2 deletions

View File

@ -218,6 +218,11 @@ class ContentRangeError(SanicException):
} }
@add_status_code(417)
class HeaderExpectationFailed(SanicException):
pass
@add_status_code(403) @add_status_code(403)
class Forbidden(SanicException): class Forbidden(SanicException):
pass pass

View File

@ -29,7 +29,7 @@ except ImportError:
DEFAULT_HTTP_CONTENT_TYPE = "application/octet-stream" DEFAULT_HTTP_CONTENT_TYPE = "application/octet-stream"
EXPECT_HEADER = "EXPECT"
# HTTP/1.1: https://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.2.1 # HTTP/1.1: https://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.2.1
# > If the media type remains unknown, the recipient SHOULD treat it # > If the media type remains unknown, the recipient SHOULD treat it

View File

@ -15,6 +15,7 @@ from httptools.parser.errors import HttpParserError
from multidict import CIMultiDict from multidict import CIMultiDict
from sanic.exceptions import ( from sanic.exceptions import (
HeaderExpectationFailed,
InvalidUsage, InvalidUsage,
PayloadTooLarge, PayloadTooLarge,
RequestTimeout, RequestTimeout,
@ -22,7 +23,7 @@ from sanic.exceptions import (
ServiceUnavailable, ServiceUnavailable,
) )
from sanic.log import access_logger, logger from sanic.log import access_logger, logger
from sanic.request import Request, StreamBuffer from sanic.request import Request, StreamBuffer, EXPECT_HEADER
from sanic.response import HTTPResponse from sanic.response import HTTPResponse
@ -314,6 +315,10 @@ class HttpProtocol(asyncio.Protocol):
if self._keep_alive_timeout_handler: if self._keep_alive_timeout_handler:
self._keep_alive_timeout_handler.cancel() self._keep_alive_timeout_handler.cancel()
self._keep_alive_timeout_handler = None self._keep_alive_timeout_handler = None
if self.request.headers.get(EXPECT_HEADER):
self.expect_handler()
if self.is_request_stream: if self.is_request_stream:
self._is_stream_handler = self.router.is_stream_handler( self._is_stream_handler = self.router.is_stream_handler(
self.request self.request
@ -324,6 +329,17 @@ class HttpProtocol(asyncio.Protocol):
) )
self.execute_request_handler() self.execute_request_handler()
def expect_handler(self):
"""
Handler for Expect Header.
"""
expect = self.request.headers.get(EXPECT_HEADER)
if self.request.version == "1.1":
if expect.lower() == "100-continue":
self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
else:
self.write_error(HeaderExpectationFailed("Unknow Expect: {expect}".format(expect=expect)))
def on_body(self, body): def on_body(self, body):
if self.is_request_stream and self._is_stream_handler: if self.is_request_stream and self._is_stream_handler:
self._request_stream_task = self.loop.create_task( self._request_stream_task = self.loop.create_task(

View File

@ -1,4 +1,6 @@
import pytest
from sanic.blueprints import Blueprint from sanic.blueprints import Blueprint
from sanic.exceptions import HeaderExpectationFailed
from sanic.request import StreamBuffer from sanic.request import StreamBuffer
from sanic.response import stream, text from sanic.response import stream, text
from sanic.views import CompositionView, HTTPMethodView from sanic.views import CompositionView, HTTPMethodView
@ -40,6 +42,53 @@ def test_request_stream_method_view(app):
assert response.text == data assert response.text == data
def test_request_stream_100_continue(app):
class SimpleView(HTTPMethodView):
@stream_decorator
async def post(self, request):
assert isinstance(request.stream, StreamBuffer)
result = ""
while True:
body = await request.stream.read()
if body is None:
break
result += body.decode("utf-8")
return text(result)
app.add_route(SimpleView.as_view(), "/method_view")
assert app.is_request_stream is True
request, response = app.test_client.post("/method_view", data=data, headers={"EXPECT": "100-continue"})
assert response.status == 200
assert response.text == data
def test_request_stream_100_continue_raise_HeaderExpectationFailed(app):
class SimpleView(HTTPMethodView):
@stream_decorator
async def post(self, request):
assert isinstance(request.stream, StreamBuffer)
result = ""
while True:
body = await request.stream.read()
if body is None:
break
result += body.decode("utf-8")
return text(result)
app.add_route(SimpleView.as_view(), "/method_view")
assert app.is_request_stream is True
with pytest.raises(ValueError) as e:
app.test_client.post("/method_view", data=data, headers={"EXPECT": "100-continue-extra"})
assert "Unknow Expect: 100-continue-extra" in str(e)
def test_request_stream_app(app): def test_request_stream_app(app):
"""for self.is_request_stream = True and decorators""" """for self.is_request_stream = True and decorators"""