Cleanup of code and avoid mixing streaming responses.
This commit is contained in:
parent
fe64a2764d
commit
f609a4850f
|
@ -47,11 +47,14 @@ class RequestParameters(dict):
|
|||
|
||||
|
||||
class StreamBuffer:
|
||||
def __init__(self, buffer_size=100):
|
||||
def __init__(self, buffer_size=100, protocol=None):
|
||||
self._queue = asyncio.Queue(buffer_size)
|
||||
self._protocol = protocol
|
||||
|
||||
async def read(self):
|
||||
""" Stop reading when gets None """
|
||||
if self._protocol:
|
||||
self._protocol.flow_control()
|
||||
payload = await self._queue.get()
|
||||
self._queue.task_done()
|
||||
return payload
|
||||
|
|
167
sanic/server.py
167
sanic/server.py
|
@ -82,13 +82,12 @@ class HttpProtocol(asyncio.Protocol):
|
|||
"_last_response_time",
|
||||
"_is_stream_handler",
|
||||
"_not_paused",
|
||||
"_request_handler_task",
|
||||
"_request_stream_task",
|
||||
"_keep_alive",
|
||||
"_header_fragment",
|
||||
"state",
|
||||
"_debug",
|
||||
"_body_chunks",
|
||||
"_handler_queue",
|
||||
"_handler_task",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
|
@ -112,13 +111,13 @@ class HttpProtocol(asyncio.Protocol):
|
|||
router=None,
|
||||
state=None,
|
||||
debug=False,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
self.loop = loop
|
||||
self.app = app
|
||||
self.transport = None
|
||||
self.request = None
|
||||
self.parser = None
|
||||
self.parser = HttpRequestParser(self)
|
||||
self.url = None
|
||||
self.headers = None
|
||||
self.router = router
|
||||
|
@ -145,8 +144,6 @@ class HttpProtocol(asyncio.Protocol):
|
|||
self._keep_alive_timeout_handler = None
|
||||
self._last_request_time = None
|
||||
self._last_response_time = None
|
||||
self._request_handler_task = None
|
||||
self._request_stream_task = None
|
||||
self._keep_alive = keep_alive
|
||||
self._header_fragment = b""
|
||||
self.state = state if state else {}
|
||||
|
@ -154,7 +151,38 @@ class HttpProtocol(asyncio.Protocol):
|
|||
self.state["requests_count"] = 0
|
||||
self._debug = debug
|
||||
self._not_paused.set()
|
||||
self._body_chunks = deque()
|
||||
self._handler_queue = asyncio.Queue(1)
|
||||
self._handler_task = self.loop.create_task(self.run_handlers())
|
||||
|
||||
async def run_handlers(self):
|
||||
"""Runs one handler at a time, in correct order.
|
||||
|
||||
Otherwise we cannot control incoming requests, this keeps their
|
||||
responses in order.
|
||||
|
||||
Handlers are inserted into the queue when headers are received. Body
|
||||
may follow later and is handled by a separate queue in req.stream.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
self.flow_control()
|
||||
handler = await self._handler_queue.get()
|
||||
await handler
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
self.transport.close()
|
||||
raise
|
||||
|
||||
def flow_control(self):
|
||||
"""Backpressure handling to avoid excessive buffering."""
|
||||
if not self.transport:
|
||||
return
|
||||
if self._handler_queue.full() or (
|
||||
self.request and self.request.stream.is_full()
|
||||
):
|
||||
self.transport.pause_reading()
|
||||
else:
|
||||
self.transport.resume_reading()
|
||||
|
||||
@property
|
||||
def keep_alive(self):
|
||||
|
@ -185,10 +213,8 @@ class HttpProtocol(asyncio.Protocol):
|
|||
|
||||
def connection_lost(self, exc):
|
||||
self.connections.discard(self)
|
||||
if self._request_handler_task:
|
||||
self._request_handler_task.cancel()
|
||||
if self._request_stream_task:
|
||||
self._request_stream_task.cancel()
|
||||
if self._handler_task:
|
||||
self._handler_task.cancel()
|
||||
if self._request_timeout_handler:
|
||||
self._request_timeout_handler.cancel()
|
||||
if self._response_timeout_handler:
|
||||
|
@ -214,10 +240,8 @@ class HttpProtocol(asyncio.Protocol):
|
|||
time_left, self.request_timeout_callback
|
||||
)
|
||||
else:
|
||||
if self._request_stream_task:
|
||||
self._request_stream_task.cancel()
|
||||
if self._request_handler_task:
|
||||
self._request_handler_task.cancel()
|
||||
if self._handler_task:
|
||||
self._handler_task.cancel()
|
||||
self.write_error(RequestTimeout("Request Timeout"))
|
||||
|
||||
def response_timeout_callback(self):
|
||||
|
@ -230,10 +254,8 @@ class HttpProtocol(asyncio.Protocol):
|
|||
time_left, self.response_timeout_callback
|
||||
)
|
||||
else:
|
||||
if self._request_stream_task:
|
||||
self._request_stream_task.cancel()
|
||||
if self._request_handler_task:
|
||||
self._request_handler_task.cancel()
|
||||
if self._handler_task:
|
||||
self._handler_task.cancel()
|
||||
self.write_error(ServiceUnavailable("Response Timeout"))
|
||||
|
||||
def keep_alive_timeout_callback(self):
|
||||
|
@ -266,15 +288,6 @@ class HttpProtocol(asyncio.Protocol):
|
|||
if self._total_request_size > self.request_max_size:
|
||||
self.write_error(PayloadTooLarge("Payload Too Large"))
|
||||
|
||||
# Create parser if this is the first time we're receiving data
|
||||
if self.parser is None:
|
||||
assert self.request is None
|
||||
self.headers = []
|
||||
self.parser = HttpRequestParser(self)
|
||||
|
||||
# requests count
|
||||
self.state["requests_count"] = self.state["requests_count"] + 1
|
||||
|
||||
# Parse request chunk or close connection
|
||||
try:
|
||||
self.parser.feed_data(data)
|
||||
|
@ -284,6 +297,13 @@ class HttpProtocol(asyncio.Protocol):
|
|||
message += "\n" + traceback.format_exc()
|
||||
self.write_error(InvalidUsage(message))
|
||||
|
||||
def on_message_begin(self):
|
||||
assert self.request is None
|
||||
self.headers = []
|
||||
|
||||
# requests count
|
||||
self.state["requests_count"] += 1
|
||||
|
||||
def on_url(self, url):
|
||||
if not self.url:
|
||||
self.url = url
|
||||
|
@ -327,8 +347,9 @@ class HttpProtocol(asyncio.Protocol):
|
|||
if self.request.headers.get(EXPECT_HEADER):
|
||||
self.expect_handler()
|
||||
|
||||
if self.is_request_stream:
|
||||
self.request.stream = StreamBuffer(self.request_buffer_queue_size)
|
||||
self.request.stream = StreamBuffer(
|
||||
self.request_buffer_queue_size, protocol=self,
|
||||
)
|
||||
self.execute_request_handler()
|
||||
|
||||
def expect_handler(self):
|
||||
|
@ -347,44 +368,10 @@ class HttpProtocol(asyncio.Protocol):
|
|||
)
|
||||
|
||||
def on_body(self, body):
|
||||
# body chunks can be put into asyncio.Queue out of order if
|
||||
# multiple tasks put concurrently and the queue is full in python
|
||||
# 3.7. so we should not create more than one task putting into the
|
||||
# queue simultaneously.
|
||||
self._body_chunks.append(body)
|
||||
if (
|
||||
not self._request_stream_task
|
||||
or self._request_stream_task.done()
|
||||
):
|
||||
self._request_stream_task = self.loop.create_task(
|
||||
self.stream_append()
|
||||
)
|
||||
|
||||
async def body_append(self, body):
|
||||
if (
|
||||
self.request is None
|
||||
or self._request_stream_task is None
|
||||
or self._request_stream_task.cancelled()
|
||||
):
|
||||
return
|
||||
|
||||
if self.request.stream.is_full():
|
||||
self.transport.pause_reading()
|
||||
await self.request.stream.put(body)
|
||||
self.transport.resume_reading()
|
||||
else:
|
||||
await self.request.stream.put(body)
|
||||
|
||||
async def stream_append(self):
|
||||
while self._body_chunks and self.request:
|
||||
body = self._body_chunks.popleft()
|
||||
if self.request.stream.is_full():
|
||||
self.transport.pause_reading()
|
||||
await self.request.stream.put(body)
|
||||
self.transport.resume_reading()
|
||||
else:
|
||||
await self.request.stream.put(body)
|
||||
self._body_chunks.clear()
|
||||
# Send request body chunks
|
||||
if body:
|
||||
self.request.stream._queue.put_nowait(body)
|
||||
self.flow_control()
|
||||
|
||||
def on_message_complete(self):
|
||||
# Entire request (headers and whole body) is received.
|
||||
|
@ -393,14 +380,9 @@ class HttpProtocol(asyncio.Protocol):
|
|||
self._request_timeout_handler.cancel()
|
||||
self._request_timeout_handler = None
|
||||
|
||||
self._body_chunks.append(None)
|
||||
if (
|
||||
not self._request_stream_task
|
||||
or self._request_stream_task.done()
|
||||
):
|
||||
self._request_stream_task = self.loop.create_task(
|
||||
self.stream_append()
|
||||
)
|
||||
# Mark the end of request body
|
||||
self.request.stream._queue.put_nowait(None)
|
||||
self.request = None
|
||||
|
||||
def execute_request_handler(self):
|
||||
"""
|
||||
|
@ -409,15 +391,17 @@ class HttpProtocol(asyncio.Protocol):
|
|||
|
||||
:return: None
|
||||
"""
|
||||
# Invoked when request headers have been received
|
||||
self._response_timeout_handler = self.loop.call_later(
|
||||
self.response_timeout, self.response_timeout_callback
|
||||
)
|
||||
self._last_request_time = time()
|
||||
self._request_handler_task = self.loop.create_task(
|
||||
self._handler_queue.put_nowait(
|
||||
self.request_handler(
|
||||
self.request, self.write_response, self.stream_response
|
||||
)
|
||||
)
|
||||
self.flow_control()
|
||||
|
||||
# -------------------------------------------- #
|
||||
# Responding
|
||||
|
@ -467,17 +451,17 @@ class HttpProtocol(asyncio.Protocol):
|
|||
try:
|
||||
keep_alive = self.keep_alive
|
||||
self.transport.write(
|
||||
response.output(
|
||||
self.request.version, keep_alive, self.keep_alive_timeout
|
||||
)
|
||||
response.output("1.1", keep_alive, self.keep_alive_timeout)
|
||||
)
|
||||
self.log_response(response)
|
||||
except AttributeError:
|
||||
if isinstance(response, HTTPResponse):
|
||||
raise
|
||||
url = self.url.decode()
|
||||
res_type = type(response).__name__
|
||||
logger.error(
|
||||
"Invalid response object for url %s, "
|
||||
"Expected Type: HTTPResponse, Actual Type: %s",
|
||||
self.url,
|
||||
type(response),
|
||||
f"Invalid response object for url {url!r}, "
|
||||
f"Expected Type: HTTPResponse, Actual Type: {res_type}"
|
||||
)
|
||||
self.write_error(ServerError("Invalid response type"))
|
||||
except RuntimeError:
|
||||
|
@ -521,9 +505,7 @@ class HttpProtocol(asyncio.Protocol):
|
|||
try:
|
||||
keep_alive = self.keep_alive
|
||||
response.protocol = self
|
||||
await response.stream(
|
||||
self.request.version, keep_alive, self.keep_alive_timeout
|
||||
)
|
||||
await response.stream("1.1", keep_alive, self.keep_alive_timeout)
|
||||
self.log_response(response)
|
||||
except AttributeError:
|
||||
logger.error(
|
||||
|
@ -564,8 +546,7 @@ class HttpProtocol(asyncio.Protocol):
|
|||
response = None
|
||||
try:
|
||||
response = self.error_handler.response(self.request, exception)
|
||||
version = self.request.version if self.request else "1.1"
|
||||
self.transport.write(response.output(version))
|
||||
self.transport.write(response.output("1.1"))
|
||||
except RuntimeError:
|
||||
if self._debug:
|
||||
logger.error(
|
||||
|
@ -621,12 +602,8 @@ class HttpProtocol(asyncio.Protocol):
|
|||
"""This is called when KeepAlive feature is used,
|
||||
it resets the connection in order for it to be able
|
||||
to handle receiving another request on the same connection."""
|
||||
self.parser = None
|
||||
self.request = None
|
||||
self.url = None
|
||||
self.headers = None
|
||||
self._request_handler_task = None
|
||||
self._request_stream_task = None
|
||||
self._total_request_size = 0
|
||||
|
||||
def close_if_idle(self):
|
||||
|
@ -888,7 +865,7 @@ def serve(
|
|||
reuse_port=reuse_port,
|
||||
sock=sock,
|
||||
backlog=backlog,
|
||||
**asyncio_server_kwargs
|
||||
**asyncio_server_kwargs,
|
||||
)
|
||||
|
||||
if run_async:
|
||||
|
|
|
@ -73,7 +73,6 @@ class SanicTestClient:
|
|||
):
|
||||
results = [None, None]
|
||||
exceptions = []
|
||||
|
||||
if gather_request:
|
||||
|
||||
def _collect_request(request):
|
||||
|
|
Loading…
Reference in New Issue
Block a user