Add back streaming requests.
This commit is contained in:
parent
3d05e1ec07
commit
ef4f233fba
|
@ -953,8 +953,8 @@ class Sanic:
|
||||||
handler, args, kwargs, uri, name = self.router.get(request)
|
handler, args, kwargs, uri, name = self.router.get(request)
|
||||||
|
|
||||||
# Non-streaming handlers have their body preloaded
|
# Non-streaming handlers have their body preloaded
|
||||||
#if not self.router.is_stream_handler(request):
|
if request.stream and not self.router.is_stream_handler(request):
|
||||||
# await request.receive_body()
|
await request.receive_body()
|
||||||
|
|
||||||
# -------------------------------------------- #
|
# -------------------------------------------- #
|
||||||
# Request Middleware
|
# Request Middleware
|
||||||
|
|
|
@ -54,7 +54,7 @@ class StreamBuffer:
|
||||||
async def read(self):
|
async def read(self):
|
||||||
""" Stop reading when gets None """
|
""" Stop reading when gets None """
|
||||||
if self._protocol:
|
if self._protocol:
|
||||||
self._protocol.flow_control()
|
return await self._protocol.request_body()
|
||||||
payload = await self._queue.get()
|
payload = await self._queue.get()
|
||||||
self._queue.task_done()
|
self._queue.task_done()
|
||||||
return payload
|
return payload
|
||||||
|
@ -119,7 +119,7 @@ class Request:
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
|
|
||||||
# Init but do not inhale
|
# Init but do not inhale
|
||||||
self.body = None
|
self.body = b""
|
||||||
self.ctx = SimpleNamespace()
|
self.ctx = SimpleNamespace()
|
||||||
self.parsed_forwarded = None
|
self.parsed_forwarded = None
|
||||||
self.parsed_json = None
|
self.parsed_json = None
|
||||||
|
|
|
@ -225,14 +225,17 @@ class HttpProtocol(asyncio.Protocol):
|
||||||
raise PayloadTooLarge("Payload Too Large")
|
raise PayloadTooLarge("Payload Too Large")
|
||||||
|
|
||||||
self._total_request_size = pos + 4
|
self._total_request_size = pos + 4
|
||||||
|
try:
|
||||||
reqline, *headers = buf[:pos].decode().split("\r\n")
|
reqline, *headers = buf[:pos].decode().split("\r\n")
|
||||||
method, self.url, protocol = reqline.split(" ")
|
method, self.url, protocol = reqline.split(" ")
|
||||||
assert protocol in ("HTTP/1.0", "HTTP/1.1")
|
assert protocol in ("HTTP/1.0", "HTTP/1.1")
|
||||||
del buf[:pos + 4]
|
|
||||||
headers = Header(
|
headers = Header(
|
||||||
(name.lower(), value.lstrip())
|
(name.lower(), value.lstrip())
|
||||||
for name, value in (h.split(":", 1) for h in headers)
|
for name, value in (h.split(":", 1) for h in headers)
|
||||||
)
|
)
|
||||||
|
except:
|
||||||
|
raise InvalidUsage("Bad Request")
|
||||||
|
del buf[:pos + 4]
|
||||||
if headers.get(EXPECT_HEADER):
|
if headers.get(EXPECT_HEADER):
|
||||||
self._status = Status.EXPECT
|
self._status = Status.EXPECT
|
||||||
self.expect_handler()
|
self.expect_handler()
|
||||||
|
@ -246,7 +249,7 @@ class HttpProtocol(asyncio.Protocol):
|
||||||
transport=self.transport,
|
transport=self.transport,
|
||||||
app=self.app,
|
app=self.app,
|
||||||
)
|
)
|
||||||
request_body = (
|
self.state["request_body"] = request_body = (
|
||||||
int(headers.get("content-length", 0)) or
|
int(headers.get("content-length", 0)) or
|
||||||
headers.get("transfer-encoding") == "chunked"
|
headers.get("transfer-encoding") == "chunked"
|
||||||
)
|
)
|
||||||
|
@ -254,8 +257,6 @@ class HttpProtocol(asyncio.Protocol):
|
||||||
self.request.stream = StreamBuffer(
|
self.request.stream = StreamBuffer(
|
||||||
self.request_buffer_queue_size, protocol=self,
|
self.request_buffer_queue_size, protocol=self,
|
||||||
)
|
)
|
||||||
self.request.stream._queue.put_nowait(None)
|
|
||||||
|
|
||||||
self._status, self._time = Status.HANDLER, current_time()
|
self._status, self._time = Status.HANDLER, current_time()
|
||||||
await self.request_handler(
|
await self.request_handler(
|
||||||
self.request, self.write_response, self.stream_response
|
self.request, self.write_response, self.stream_response
|
||||||
|
@ -268,6 +269,40 @@ class HttpProtocol(asyncio.Protocol):
|
||||||
finally:
|
finally:
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
|
async def request_body(self):
|
||||||
|
rb = self.state["request_body"]
|
||||||
|
if rb is True:
|
||||||
|
# This code is crap and needs rewriting
|
||||||
|
while b"\r\n" not in self._buffer:
|
||||||
|
await self.receive_more()
|
||||||
|
pos = self._buffer.find(b"\r\n")
|
||||||
|
size = int(self._buffer[:pos])
|
||||||
|
if self._total_request_size + size > self.request_max_size:
|
||||||
|
self.keep_alive = False
|
||||||
|
raise PayloadTooLarge("Payload Too Large")
|
||||||
|
self._total_request_size += pos + 4 + size
|
||||||
|
if size == 0:
|
||||||
|
self.state["request_body"] = 0
|
||||||
|
return None
|
||||||
|
while len(self._buffer) < pos + 4 + size:
|
||||||
|
await self.receive_more()
|
||||||
|
body = self._buffer[pos + 2: pos + 2 + size]
|
||||||
|
del body[:pos + 4 + size]
|
||||||
|
return body
|
||||||
|
elif rb > 0:
|
||||||
|
if not self._buffer:
|
||||||
|
await self.receive_more()
|
||||||
|
body = self._buffer[:rb]
|
||||||
|
size = len(body)
|
||||||
|
del self._buffer[:rb]
|
||||||
|
self.state["request_body"] -= size
|
||||||
|
self._total_request_size += size
|
||||||
|
if self._total_request_size > self.request_max_size:
|
||||||
|
self.keep_alive = False
|
||||||
|
raise PayloadTooLarge("Payload Too Large")
|
||||||
|
return body
|
||||||
|
return None
|
||||||
|
|
||||||
def expect_handler(self):
|
def expect_handler(self):
|
||||||
"""
|
"""
|
||||||
Handler for Expect Header.
|
Handler for Expect Header.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user