Code cleanup, 14 tests failing.
This commit is contained in:
parent
8a1baeb9d5
commit
57202bfa89
|
@ -20,7 +20,7 @@ import sanic.app # noqa
|
|||
from sanic.compat import Header
|
||||
from sanic.exceptions import InvalidUsage, ServerError
|
||||
from sanic.log import logger
|
||||
from sanic.request import Request, StreamBuffer
|
||||
from sanic.request import Request
|
||||
from sanic.response import HTTPResponse, StreamingHTTPResponse
|
||||
from sanic.websocket import WebSocketConnection
|
||||
|
||||
|
@ -250,11 +250,11 @@ class ASGIApp:
|
|||
instance.transport,
|
||||
sanic_app,
|
||||
)
|
||||
instance.request.stream = StreamBuffer(protocol=instance)
|
||||
instance.request.stream = instance
|
||||
|
||||
return instance
|
||||
|
||||
async def stream_body(self) -> None:
|
||||
async def read(self) -> None:
|
||||
"""
|
||||
Read and stream the body in chunks from an incoming ASGI message.
|
||||
"""
|
||||
|
@ -263,6 +263,13 @@ class ASGIApp:
|
|||
return None
|
||||
return message.get("body", b"")
|
||||
|
||||
async def __aiter__(self):
|
||||
while True:
|
||||
data = await self.read()
|
||||
if not data:
|
||||
return
|
||||
yield data
|
||||
|
||||
def respond(self, response):
|
||||
headers: List[Tuple[bytes, bytes]] = []
|
||||
cookies: Dict[str, str] = {}
|
||||
|
|
|
@ -17,7 +17,7 @@ from sanic.response import HTTPResponse
|
|||
from sanic.compat import Header
|
||||
|
||||
|
||||
class Lifespan(Enum):
|
||||
class Stage(Enum):
|
||||
IDLE = 0 # Waiting for request
|
||||
REQUEST = 1 # Request headers being received
|
||||
HANDLER = 3 # Headers done, handler running
|
||||
|
@ -29,20 +29,21 @@ HTTP_CONTINUE = b"HTTP/1.1 100 Continue\r\n\r\n"
|
|||
|
||||
class Http:
|
||||
def __init__(self, protocol):
|
||||
self._send = protocol.push_data
|
||||
self._send = protocol.send
|
||||
self._receive_more = protocol.receive_more
|
||||
self.recv_buffer = protocol.recv_buffer
|
||||
self.protocol = protocol
|
||||
self.recv_buffer = bytearray()
|
||||
self.expecting_continue = False
|
||||
# Note: connections are initially in request mode and do not obey
|
||||
# keep-alive timeout like with some other servers.
|
||||
self.lifespan = Lifespan.REQUEST
|
||||
self.stage = Stage.REQUEST
|
||||
self.keep_alive = True
|
||||
self.head_only = None
|
||||
|
||||
async def http1(self):
|
||||
"""HTTP 1.1 connection handler"""
|
||||
buf = self.recv_buffer
|
||||
self.keep_alive = True
|
||||
url = None
|
||||
self.url = None
|
||||
while self.keep_alive:
|
||||
# Read request header
|
||||
pos = 0
|
||||
|
@ -53,17 +54,17 @@ class Http:
|
|||
break
|
||||
pos = max(0, len(buf) - 3)
|
||||
await self._receive_more()
|
||||
if self.lifespan is Lifespan.IDLE:
|
||||
self.lifespan = Lifespan.REQUEST
|
||||
if self.stage is Stage.IDLE:
|
||||
self.stage = Stage.REQUEST
|
||||
else:
|
||||
self.lifespan = Lifespan.HANDLER
|
||||
self.stage = Stage.HANDLER
|
||||
raise PayloadTooLarge("Payload Too Large")
|
||||
|
||||
self.protocol._total_request_size = pos + 4
|
||||
|
||||
try:
|
||||
reqline, *headers = buf[:pos].decode().split("\r\n")
|
||||
method, url, protocol = reqline.split(" ")
|
||||
method, self.url, protocol = reqline.split(" ")
|
||||
if protocol not in ("HTTP/1.0", "HTTP/1.1"):
|
||||
raise Exception
|
||||
self.head_only = method.upper() == "HEAD"
|
||||
|
@ -72,12 +73,12 @@ class Http:
|
|||
for name, value in (h.split(":", 1) for h in headers)
|
||||
)
|
||||
except:
|
||||
self.lifespan = Lifespan.HANDLER
|
||||
self.stage = Stage.HANDLER
|
||||
raise InvalidUsage("Bad Request")
|
||||
|
||||
# Prepare a request object from the header received
|
||||
request = self.protocol.request_class(
|
||||
url_bytes=url.encode(),
|
||||
url_bytes=self.url.encode(),
|
||||
headers=headers,
|
||||
version=protocol[-3:],
|
||||
method=method,
|
||||
|
@ -86,8 +87,6 @@ class Http:
|
|||
)
|
||||
request.stream = self
|
||||
self.protocol.state["requests_count"] += 1
|
||||
self.protocol.url = url
|
||||
self.protocol.request = request
|
||||
self.keep_alive = (
|
||||
protocol == "HTTP/1.1"
|
||||
or headers.get("connection", "").lower() == "keep-alive"
|
||||
|
@ -98,7 +97,7 @@ class Http:
|
|||
)
|
||||
self.request_chunked = False
|
||||
self.request_bytes_left = 0
|
||||
self.lifespan = Lifespan.HANDLER
|
||||
self.stage = Stage.HANDLER
|
||||
if body:
|
||||
expect = headers.get("expect")
|
||||
if expect:
|
||||
|
@ -121,14 +120,14 @@ class Http:
|
|||
except Exception:
|
||||
logger.exception("Uncaught from app/handler")
|
||||
await self.write_error(ServerError("Internal Server Error"))
|
||||
if self.lifespan is Lifespan.IDLE:
|
||||
if self.stage is Stage.IDLE:
|
||||
continue
|
||||
|
||||
if self.lifespan is Lifespan.HANDLER:
|
||||
if self.stage is Stage.HANDLER:
|
||||
await self.respond(HTTPResponse(status=204)).send(end_stream=True)
|
||||
|
||||
# Finish sending a response (if no error)
|
||||
elif self.lifespan is Lifespan.RESPONSE:
|
||||
elif self.stage is Stage.RESPONSE:
|
||||
await self.send(end_stream=True)
|
||||
|
||||
# Consume any remaining request body
|
||||
|
@ -139,10 +138,10 @@ class Http:
|
|||
while await self.read():
|
||||
pass
|
||||
|
||||
self.lifespan = Lifespan.IDLE
|
||||
self.stage = Stage.IDLE
|
||||
|
||||
async def write_error(self, e):
|
||||
if self.lifespan is Lifespan.HANDLER:
|
||||
if self.stage is Stage.HANDLER:
|
||||
try:
|
||||
response = HTTPResponse(f"{e}", e.status_code, content_type="text/plain")
|
||||
await self.respond(response).send(end_stream=True)
|
||||
|
@ -210,8 +209,8 @@ class Http:
|
|||
Nothing is sent until the first send() call on the returned object, and
|
||||
calling this function multiple times will just alter the response to be
|
||||
given."""
|
||||
if self.lifespan is not Lifespan.HANDLER:
|
||||
self.lifespan = Lifespan.FAILED
|
||||
if self.stage is not Stage.HANDLER:
|
||||
self.stage = Stage.FAILED
|
||||
raise RuntimeError("Response already started")
|
||||
if not isinstance(response.status, int) or response.status < 200:
|
||||
raise RuntimeError(f"Invalid response status {response.status!r}")
|
||||
|
@ -240,7 +239,7 @@ class Http:
|
|||
size = len(data) if data is not None else 0
|
||||
|
||||
# Headers not yet sent?
|
||||
if self.lifespan is Lifespan.HANDLER:
|
||||
if self.stage is Stage.HANDLER:
|
||||
if self.response.body:
|
||||
data = self.response.body + data if data else self.response.body
|
||||
size = len(data)
|
||||
|
@ -297,14 +296,14 @@ class Http:
|
|||
if status != 417:
|
||||
ret = HTTP_CONTINUE + ret
|
||||
# Send response
|
||||
self.lifespan = Lifespan.IDLE if end_stream else Lifespan.RESPONSE
|
||||
self.stage = Stage.IDLE if end_stream else Stage.RESPONSE
|
||||
return ret
|
||||
|
||||
# HEAD request: don't send body
|
||||
if self.head_only:
|
||||
return None
|
||||
|
||||
if self.lifespan is not Lifespan.RESPONSE:
|
||||
if self.stage is not Stage.RESPONSE:
|
||||
if size:
|
||||
raise RuntimeError("Cannot send data to a closed stream")
|
||||
return
|
||||
|
@ -313,7 +312,7 @@ class Http:
|
|||
if self.response_bytes_left is True:
|
||||
if end_stream:
|
||||
self.response_bytes_left = None
|
||||
self.lifespan = Lifespan.IDLE
|
||||
self.stage = Stage.IDLE
|
||||
if size:
|
||||
return b"%x\r\n%b\r\n0\r\n\r\n" % (size, data)
|
||||
return b"0\r\n\r\n"
|
||||
|
@ -325,5 +324,5 @@ class Http:
|
|||
if self.response_bytes_left <= 0:
|
||||
if self.response_bytes_left < 0:
|
||||
raise ServerError("Response was bigger than content-length")
|
||||
self.lifespan = Lifespan.IDLE
|
||||
self.stage = Stage.IDLE
|
||||
return data if size else None
|
||||
|
|
|
@ -47,19 +47,6 @@ class RequestParameters(dict):
|
|||
return super().get(name, default)
|
||||
|
||||
|
||||
class StreamBuffer:
|
||||
def __init__(self, protocol):
|
||||
self.read = protocol.stream_body
|
||||
self.respond = protocol.respond
|
||||
|
||||
async def __aiter__(self):
|
||||
while True:
|
||||
data = await self.read()
|
||||
if not data:
|
||||
return
|
||||
yield data
|
||||
|
||||
|
||||
class Request:
|
||||
"""Properties of an HTTP request such as URL, headers, etc."""
|
||||
|
||||
|
|
168
sanic/server.py
168
sanic/server.py
|
@ -23,7 +23,7 @@ from sanic.exceptions import (
|
|||
ServerError,
|
||||
ServiceUnavailable,
|
||||
)
|
||||
from sanic.http import Http, Lifespan
|
||||
from sanic.http import Http, Stage
|
||||
from sanic.log import access_logger, logger
|
||||
from sanic.request import Request
|
||||
|
||||
|
@ -71,7 +71,6 @@ class HttpProtocol(asyncio.Protocol):
|
|||
# connection management
|
||||
"_total_request_size",
|
||||
"_last_response_time",
|
||||
"keep_alive",
|
||||
"state",
|
||||
"url",
|
||||
"_debug",
|
||||
|
@ -82,6 +81,7 @@ class HttpProtocol(asyncio.Protocol):
|
|||
"_task",
|
||||
"_http",
|
||||
"_exception",
|
||||
"recv_buffer",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
|
@ -126,7 +126,6 @@ class HttpProtocol(asyncio.Protocol):
|
|||
self.request_max_size = request_max_size
|
||||
self.request_class = request_class or Request
|
||||
self._total_request_size = 0
|
||||
self.keep_alive = keep_alive
|
||||
self.state = state if state else {}
|
||||
if "requests_count" not in self.state:
|
||||
self.state["requests_count"] = 0
|
||||
|
@ -136,30 +135,40 @@ class HttpProtocol(asyncio.Protocol):
|
|||
self._can_write.set()
|
||||
self._exception = None
|
||||
|
||||
# -------------------------------------------- #
|
||||
# Connection
|
||||
# -------------------------------------------- #
|
||||
async def connection_task(self):
|
||||
"""Run a HTTP connection.
|
||||
|
||||
def connection_made(self, transport):
|
||||
self.connections.add(self)
|
||||
self.transport = transport
|
||||
#self.check_timeouts()
|
||||
self._http = Http(self)
|
||||
self._task = self.loop.create_task(self.connection_task())
|
||||
self._time = current_time()
|
||||
self.check_timeouts()
|
||||
|
||||
def connection_lost(self, exc):
|
||||
self.connections.discard(self)
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
Timeouts and some additional error handling occur here, while most of
|
||||
everything else happens in class Http or in code called from there.
|
||||
"""
|
||||
try:
|
||||
self._http = Http(self)
|
||||
self._time = current_time()
|
||||
self.check_timeouts()
|
||||
try:
|
||||
await self._http.http1()
|
||||
except asyncio.CancelledError:
|
||||
await self._http.write_error(
|
||||
self._exception
|
||||
or ServiceUnavailable("Request handler cancelled")
|
||||
)
|
||||
except SanicException as e:
|
||||
await self._http.write_error(e)
|
||||
except BaseException as e:
|
||||
logger.exception(
|
||||
f"Uncaught exception while handling URL {self._http.url}"
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except:
|
||||
logger.exception("protocol.connection_task uncaught")
|
||||
finally:
|
||||
self._http = None
|
||||
self._task = None
|
||||
|
||||
def pause_writing(self):
|
||||
self._can_write.clear()
|
||||
|
||||
def resume_writing(self):
|
||||
self._can_write.set()
|
||||
try:
|
||||
self.close()
|
||||
except:
|
||||
logger.exception("Closing failed")
|
||||
|
||||
async def receive_more(self):
|
||||
"""Wait until more data is received into self._buffer."""
|
||||
|
@ -167,51 +176,18 @@ class HttpProtocol(asyncio.Protocol):
|
|||
self._data_received.clear()
|
||||
await self._data_received.wait()
|
||||
|
||||
# -------------------------------------------- #
|
||||
# Parsing
|
||||
# -------------------------------------------- #
|
||||
|
||||
def data_received(self, data):
|
||||
self._time = current_time()
|
||||
if not data:
|
||||
return self.close()
|
||||
self._http.recv_buffer += data
|
||||
if len(self._http.recv_buffer) > self.request_max_size:
|
||||
self.transport.pause_reading()
|
||||
|
||||
if self._data_received:
|
||||
self._data_received.set()
|
||||
|
||||
async def connection_task(self):
|
||||
try:
|
||||
await self._http.http1()
|
||||
except asyncio.CancelledError:
|
||||
await self._http.write_error(
|
||||
self._exception
|
||||
or ServiceUnavailable("Request handler cancelled")
|
||||
)
|
||||
except SanicException as e:
|
||||
await self._http.write_error(e)
|
||||
except BaseException as e:
|
||||
logger.exception(f"Uncaught exception while handling URL {url}")
|
||||
finally:
|
||||
try:
|
||||
self.close()
|
||||
except:
|
||||
logger.exception("Closing failed")
|
||||
|
||||
def check_timeouts(self):
|
||||
"""Runs itself once a second to enforce any expired timeouts."""
|
||||
if not self._task:
|
||||
return
|
||||
duration = current_time() - self._time
|
||||
lifespan = self._http.lifespan
|
||||
if lifespan == Lifespan.IDLE and duration > self.keep_alive_timeout:
|
||||
stage = self._http.stage
|
||||
if stage is Stage.IDLE and duration > self.keep_alive_timeout:
|
||||
logger.debug("KeepAlive Timeout. Closing connection.")
|
||||
elif lifespan == Lifespan.REQUEST and duration > self.request_timeout:
|
||||
elif stage is Stage.REQUEST and duration > self.request_timeout:
|
||||
self._exception = RequestTimeout("Request Timeout")
|
||||
elif (
|
||||
lifespan.value > Lifespan.REQUEST.value
|
||||
stage in (Stage.REQUEST, Stage.FAILED)
|
||||
and duration > self.response_timeout
|
||||
):
|
||||
self._exception = ServiceUnavailable("Response Timeout")
|
||||
|
@ -220,20 +196,18 @@ class HttpProtocol(asyncio.Protocol):
|
|||
return
|
||||
self._task.cancel()
|
||||
|
||||
async def drain(self):
|
||||
async def send(self, data):
|
||||
"""Writes data with backpressure control."""
|
||||
await self._can_write.wait()
|
||||
|
||||
async def push_data(self, data):
|
||||
self._time = current_time()
|
||||
await self.drain()
|
||||
self.transport.write(data)
|
||||
self._time = current_time()
|
||||
|
||||
def close_if_idle(self):
|
||||
"""Close the connection if a request is not being sent or received
|
||||
|
||||
:return: boolean - True if closed, false if staying open
|
||||
"""
|
||||
if self._http.lifespan == Lifespan.IDLE:
|
||||
if self._http.stage is Stage.IDLE:
|
||||
self.close()
|
||||
return True
|
||||
return False
|
||||
|
@ -242,16 +216,56 @@ class HttpProtocol(asyncio.Protocol):
|
|||
"""
|
||||
Force close the connection.
|
||||
"""
|
||||
if self.transport is not None:
|
||||
try:
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
self._task = None
|
||||
self.transport.close()
|
||||
self.resume_writing()
|
||||
finally:
|
||||
self.transport = None
|
||||
# Cause a call to connection_lost where further cleanup occurs
|
||||
if self.transport:
|
||||
self.transport.close()
|
||||
self.transport = None
|
||||
|
||||
# -------------------------------------------- #
|
||||
# Only asyncio.Protocol callbacks below this
|
||||
# -------------------------------------------- #
|
||||
|
||||
def connection_made(self, transport):
|
||||
try:
|
||||
# TODO: Benchmark to find suitable write buffer limits
|
||||
transport.set_write_buffer_limits(low=16384, high=65536)
|
||||
self.connections.add(self)
|
||||
self.transport = transport
|
||||
self._task = self.loop.create_task(self.connection_task())
|
||||
self.recv_buffer = bytearray()
|
||||
except:
|
||||
logger.exception("protocol.connect_made")
|
||||
|
||||
def connection_lost(self, exc):
|
||||
try:
|
||||
self.connections.discard(self)
|
||||
self.resume_writing()
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
except:
|
||||
logger.exception("protocol.connection_lost")
|
||||
|
||||
def pause_writing(self):
|
||||
self._can_write.clear()
|
||||
|
||||
def resume_writing(self):
|
||||
self._can_write.set()
|
||||
|
||||
def data_received(self, data):
|
||||
try:
|
||||
self._time = current_time()
|
||||
if not data:
|
||||
return self.close()
|
||||
self.recv_buffer += data
|
||||
|
||||
# Buffer up to request max size
|
||||
if len(self.recv_buffer) > self.request_max_size:
|
||||
self.transport.pause_reading()
|
||||
|
||||
if self._data_received:
|
||||
self._data_received.set()
|
||||
except:
|
||||
logger.exception("protocol.data_received")
|
||||
|
||||
def trigger_events(events, loop):
|
||||
"""Trigger event callbacks (functions or async)
|
||||
|
|
Loading…
Reference in New Issue
Block a user