Code cleanup, 14 tests failing.

This commit is contained in:
L. Kärkkäinen 2020-02-29 14:18:31 +02:00
parent 8a1baeb9d5
commit 57202bfa89
4 changed files with 127 additions and 120 deletions

View File

@ -20,7 +20,7 @@ import sanic.app # noqa
from sanic.compat import Header
from sanic.exceptions import InvalidUsage, ServerError
from sanic.log import logger
from sanic.request import Request, StreamBuffer
from sanic.request import Request
from sanic.response import HTTPResponse, StreamingHTTPResponse
from sanic.websocket import WebSocketConnection
@ -250,11 +250,11 @@ class ASGIApp:
instance.transport,
sanic_app,
)
instance.request.stream = StreamBuffer(protocol=instance)
instance.request.stream = instance
return instance
async def stream_body(self) -> None:
async def read(self) -> None:
"""
Read and stream the body in chunks from an incoming ASGI message.
"""
@ -263,6 +263,13 @@ class ASGIApp:
return None
return message.get("body", b"")
async def __aiter__(self):
while True:
data = await self.read()
if not data:
return
yield data
def respond(self, response):
headers: List[Tuple[bytes, bytes]] = []
cookies: Dict[str, str] = {}

View File

@ -17,7 +17,7 @@ from sanic.response import HTTPResponse
from sanic.compat import Header
class Lifespan(Enum):
class Stage(Enum):
IDLE = 0 # Waiting for request
REQUEST = 1 # Request headers being received
HANDLER = 3 # Headers done, handler running
@ -29,20 +29,21 @@ HTTP_CONTINUE = b"HTTP/1.1 100 Continue\r\n\r\n"
class Http:
def __init__(self, protocol):
self._send = protocol.push_data
self._send = protocol.send
self._receive_more = protocol.receive_more
self.recv_buffer = protocol.recv_buffer
self.protocol = protocol
self.recv_buffer = bytearray()
self.expecting_continue = False
# Note: connections are initially in request mode and do not obey
# keep-alive timeout like with some other servers.
self.lifespan = Lifespan.REQUEST
self.stage = Stage.REQUEST
self.keep_alive = True
self.head_only = None
async def http1(self):
"""HTTP 1.1 connection handler"""
buf = self.recv_buffer
self.keep_alive = True
url = None
self.url = None
while self.keep_alive:
# Read request header
pos = 0
@ -53,17 +54,17 @@ class Http:
break
pos = max(0, len(buf) - 3)
await self._receive_more()
if self.lifespan is Lifespan.IDLE:
self.lifespan = Lifespan.REQUEST
if self.stage is Stage.IDLE:
self.stage = Stage.REQUEST
else:
self.lifespan = Lifespan.HANDLER
self.stage = Stage.HANDLER
raise PayloadTooLarge("Payload Too Large")
self.protocol._total_request_size = pos + 4
try:
reqline, *headers = buf[:pos].decode().split("\r\n")
method, url, protocol = reqline.split(" ")
method, self.url, protocol = reqline.split(" ")
if protocol not in ("HTTP/1.0", "HTTP/1.1"):
raise Exception
self.head_only = method.upper() == "HEAD"
@ -72,12 +73,12 @@ class Http:
for name, value in (h.split(":", 1) for h in headers)
)
except:
self.lifespan = Lifespan.HANDLER
self.stage = Stage.HANDLER
raise InvalidUsage("Bad Request")
# Prepare a request object from the header received
request = self.protocol.request_class(
url_bytes=url.encode(),
url_bytes=self.url.encode(),
headers=headers,
version=protocol[-3:],
method=method,
@ -86,8 +87,6 @@ class Http:
)
request.stream = self
self.protocol.state["requests_count"] += 1
self.protocol.url = url
self.protocol.request = request
self.keep_alive = (
protocol == "HTTP/1.1"
or headers.get("connection", "").lower() == "keep-alive"
@ -98,7 +97,7 @@ class Http:
)
self.request_chunked = False
self.request_bytes_left = 0
self.lifespan = Lifespan.HANDLER
self.stage = Stage.HANDLER
if body:
expect = headers.get("expect")
if expect:
@ -121,14 +120,14 @@ class Http:
except Exception:
logger.exception("Uncaught from app/handler")
await self.write_error(ServerError("Internal Server Error"))
if self.lifespan is Lifespan.IDLE:
if self.stage is Stage.IDLE:
continue
if self.lifespan is Lifespan.HANDLER:
if self.stage is Stage.HANDLER:
await self.respond(HTTPResponse(status=204)).send(end_stream=True)
# Finish sending a response (if no error)
elif self.lifespan is Lifespan.RESPONSE:
elif self.stage is Stage.RESPONSE:
await self.send(end_stream=True)
# Consume any remaining request body
@ -139,10 +138,10 @@ class Http:
while await self.read():
pass
self.lifespan = Lifespan.IDLE
self.stage = Stage.IDLE
async def write_error(self, e):
if self.lifespan is Lifespan.HANDLER:
if self.stage is Stage.HANDLER:
try:
response = HTTPResponse(f"{e}", e.status_code, content_type="text/plain")
await self.respond(response).send(end_stream=True)
@ -210,8 +209,8 @@ class Http:
Nothing is sent until the first send() call on the returned object, and
calling this function multiple times will just alter the response to be
given."""
if self.lifespan is not Lifespan.HANDLER:
self.lifespan = Lifespan.FAILED
if self.stage is not Stage.HANDLER:
self.stage = Stage.FAILED
raise RuntimeError("Response already started")
if not isinstance(response.status, int) or response.status < 200:
raise RuntimeError(f"Invalid response status {response.status!r}")
@ -240,7 +239,7 @@ class Http:
size = len(data) if data is not None else 0
# Headers not yet sent?
if self.lifespan is Lifespan.HANDLER:
if self.stage is Stage.HANDLER:
if self.response.body:
data = self.response.body + data if data else self.response.body
size = len(data)
@ -297,14 +296,14 @@ class Http:
if status != 417:
ret = HTTP_CONTINUE + ret
# Send response
self.lifespan = Lifespan.IDLE if end_stream else Lifespan.RESPONSE
self.stage = Stage.IDLE if end_stream else Stage.RESPONSE
return ret
# HEAD request: don't send body
if self.head_only:
return None
if self.lifespan is not Lifespan.RESPONSE:
if self.stage is not Stage.RESPONSE:
if size:
raise RuntimeError("Cannot send data to a closed stream")
return
@ -313,7 +312,7 @@ class Http:
if self.response_bytes_left is True:
if end_stream:
self.response_bytes_left = None
self.lifespan = Lifespan.IDLE
self.stage = Stage.IDLE
if size:
return b"%x\r\n%b\r\n0\r\n\r\n" % (size, data)
return b"0\r\n\r\n"
@ -325,5 +324,5 @@ class Http:
if self.response_bytes_left <= 0:
if self.response_bytes_left < 0:
raise ServerError("Response was bigger than content-length")
self.lifespan = Lifespan.IDLE
self.stage = Stage.IDLE
return data if size else None

View File

@ -47,19 +47,6 @@ class RequestParameters(dict):
return super().get(name, default)
class StreamBuffer:
def __init__(self, protocol):
self.read = protocol.stream_body
self.respond = protocol.respond
async def __aiter__(self):
while True:
data = await self.read()
if not data:
return
yield data
class Request:
"""Properties of an HTTP request such as URL, headers, etc."""

View File

@ -23,7 +23,7 @@ from sanic.exceptions import (
ServerError,
ServiceUnavailable,
)
from sanic.http import Http, Lifespan
from sanic.http import Http, Stage
from sanic.log import access_logger, logger
from sanic.request import Request
@ -71,7 +71,6 @@ class HttpProtocol(asyncio.Protocol):
# connection management
"_total_request_size",
"_last_response_time",
"keep_alive",
"state",
"url",
"_debug",
@ -82,6 +81,7 @@ class HttpProtocol(asyncio.Protocol):
"_task",
"_http",
"_exception",
"recv_buffer",
)
def __init__(
@ -126,7 +126,6 @@ class HttpProtocol(asyncio.Protocol):
self.request_max_size = request_max_size
self.request_class = request_class or Request
self._total_request_size = 0
self.keep_alive = keep_alive
self.state = state if state else {}
if "requests_count" not in self.state:
self.state["requests_count"] = 0
@ -136,53 +135,16 @@ class HttpProtocol(asyncio.Protocol):
self._can_write.set()
self._exception = None
# -------------------------------------------- #
# Connection
# -------------------------------------------- #
async def connection_task(self):
"""Run a HTTP connection.
def connection_made(self, transport):
self.connections.add(self)
self.transport = transport
#self.check_timeouts()
Timeouts and some additional error handling occur here, while most of
everything else happens in class Http or in code called from there.
"""
try:
self._http = Http(self)
self._task = self.loop.create_task(self.connection_task())
self._time = current_time()
self.check_timeouts()
def connection_lost(self, exc):
self.connections.discard(self)
if self._task:
self._task.cancel()
self._task = None
def pause_writing(self):
self._can_write.clear()
def resume_writing(self):
self._can_write.set()
async def receive_more(self):
"""Wait until more data is received into self._buffer."""
self.transport.resume_reading()
self._data_received.clear()
await self._data_received.wait()
# -------------------------------------------- #
# Parsing
# -------------------------------------------- #
def data_received(self, data):
self._time = current_time()
if not data:
return self.close()
self._http.recv_buffer += data
if len(self._http.recv_buffer) > self.request_max_size:
self.transport.pause_reading()
if self._data_received:
self._data_received.set()
async def connection_task(self):
try:
await self._http.http1()
except asyncio.CancelledError:
@ -193,25 +155,39 @@ class HttpProtocol(asyncio.Protocol):
except SanicException as e:
await self._http.write_error(e)
except BaseException as e:
logger.exception(f"Uncaught exception while handling URL {url}")
logger.exception(
f"Uncaught exception while handling URL {self._http.url}"
)
except asyncio.CancelledError:
pass
except:
logger.exception("protocol.connection_task uncaught")
finally:
self._http = None
self._task = None
try:
self.close()
except:
logger.exception("Closing failed")
async def receive_more(self):
"""Wait until more data is received into self._buffer."""
self.transport.resume_reading()
self._data_received.clear()
await self._data_received.wait()
def check_timeouts(self):
"""Runs itself once a second to enforce any expired timeouts."""
if not self._task:
return
duration = current_time() - self._time
lifespan = self._http.lifespan
if lifespan == Lifespan.IDLE and duration > self.keep_alive_timeout:
stage = self._http.stage
if stage is Stage.IDLE and duration > self.keep_alive_timeout:
logger.debug("KeepAlive Timeout. Closing connection.")
elif lifespan == Lifespan.REQUEST and duration > self.request_timeout:
elif stage is Stage.REQUEST and duration > self.request_timeout:
self._exception = RequestTimeout("Request Timeout")
elif (
lifespan.value > Lifespan.REQUEST.value
stage in (Stage.REQUEST, Stage.FAILED)
and duration > self.response_timeout
):
self._exception = ServiceUnavailable("Response Timeout")
@ -220,20 +196,18 @@ class HttpProtocol(asyncio.Protocol):
return
self._task.cancel()
async def drain(self):
async def send(self, data):
"""Writes data with backpressure control."""
await self._can_write.wait()
async def push_data(self, data):
self._time = current_time()
await self.drain()
self.transport.write(data)
self._time = current_time()
def close_if_idle(self):
"""Close the connection if a request is not being sent or received
:return: boolean - True if closed, false if staying open
"""
if self._http.lifespan == Lifespan.IDLE:
if self._http.stage is Stage.IDLE:
self.close()
return True
return False
@ -242,16 +216,56 @@ class HttpProtocol(asyncio.Protocol):
"""
Force close the connection.
"""
if self.transport is not None:
try:
if self._task:
self._task.cancel()
self._task = None
# Cause a call to connection_lost where further cleanup occurs
if self.transport:
self.transport.close()
self.resume_writing()
finally:
self.transport = None
# -------------------------------------------- #
# Only asyncio.Protocol callbacks below this
# -------------------------------------------- #
def connection_made(self, transport):
try:
# TODO: Benchmark to find suitable write buffer limits
transport.set_write_buffer_limits(low=16384, high=65536)
self.connections.add(self)
self.transport = transport
self._task = self.loop.create_task(self.connection_task())
self.recv_buffer = bytearray()
except:
logger.exception("protocol.connect_made")
def connection_lost(self, exc):
try:
self.connections.discard(self)
self.resume_writing()
if self._task:
self._task.cancel()
except:
logger.exception("protocol.connection_lost")
def pause_writing(self):
self._can_write.clear()
def resume_writing(self):
self._can_write.set()
def data_received(self, data):
try:
self._time = current_time()
if not data:
return self.close()
self.recv_buffer += data
# Buffer up to request max size
if len(self.recv_buffer) > self.request_max_size:
self.transport.pause_reading()
if self._data_received:
self._data_received.set()
except:
logger.exception("protocol.data_received")
def trigger_events(events, loop):
"""Trigger event callbacks (functions or async)