Code cleanup, 14 tests failing.

This commit is contained in:
L. Kärkkäinen 2020-02-29 14:18:31 +02:00
parent 8a1baeb9d5
commit 57202bfa89
4 changed files with 127 additions and 120 deletions

View File

@ -20,7 +20,7 @@ import sanic.app # noqa
from sanic.compat import Header from sanic.compat import Header
from sanic.exceptions import InvalidUsage, ServerError from sanic.exceptions import InvalidUsage, ServerError
from sanic.log import logger from sanic.log import logger
from sanic.request import Request, StreamBuffer from sanic.request import Request
from sanic.response import HTTPResponse, StreamingHTTPResponse from sanic.response import HTTPResponse, StreamingHTTPResponse
from sanic.websocket import WebSocketConnection from sanic.websocket import WebSocketConnection
@ -250,11 +250,11 @@ class ASGIApp:
instance.transport, instance.transport,
sanic_app, sanic_app,
) )
instance.request.stream = StreamBuffer(protocol=instance) instance.request.stream = instance
return instance return instance
async def stream_body(self) -> None: async def read(self) -> None:
""" """
Read and stream the body in chunks from an incoming ASGI message. Read and stream the body in chunks from an incoming ASGI message.
""" """
@ -263,6 +263,13 @@ class ASGIApp:
return None return None
return message.get("body", b"") return message.get("body", b"")
async def __aiter__(self):
while True:
data = await self.read()
if not data:
return
yield data
def respond(self, response): def respond(self, response):
headers: List[Tuple[bytes, bytes]] = [] headers: List[Tuple[bytes, bytes]] = []
cookies: Dict[str, str] = {} cookies: Dict[str, str] = {}

View File

@ -17,7 +17,7 @@ from sanic.response import HTTPResponse
from sanic.compat import Header from sanic.compat import Header
class Lifespan(Enum): class Stage(Enum):
IDLE = 0 # Waiting for request IDLE = 0 # Waiting for request
REQUEST = 1 # Request headers being received REQUEST = 1 # Request headers being received
HANDLER = 3 # Headers done, handler running HANDLER = 3 # Headers done, handler running
@ -29,20 +29,21 @@ HTTP_CONTINUE = b"HTTP/1.1 100 Continue\r\n\r\n"
class Http: class Http:
def __init__(self, protocol): def __init__(self, protocol):
self._send = protocol.push_data self._send = protocol.send
self._receive_more = protocol.receive_more self._receive_more = protocol.receive_more
self.recv_buffer = protocol.recv_buffer
self.protocol = protocol self.protocol = protocol
self.recv_buffer = bytearray()
self.expecting_continue = False self.expecting_continue = False
# Note: connections are initially in request mode and do not obey # Note: connections are initially in request mode and do not obey
# keep-alive timeout like with some other servers. # keep-alive timeout like with some other servers.
self.lifespan = Lifespan.REQUEST self.stage = Stage.REQUEST
self.keep_alive = True
self.head_only = None
async def http1(self): async def http1(self):
"""HTTP 1.1 connection handler""" """HTTP 1.1 connection handler"""
buf = self.recv_buffer buf = self.recv_buffer
self.keep_alive = True self.url = None
url = None
while self.keep_alive: while self.keep_alive:
# Read request header # Read request header
pos = 0 pos = 0
@ -53,17 +54,17 @@ class Http:
break break
pos = max(0, len(buf) - 3) pos = max(0, len(buf) - 3)
await self._receive_more() await self._receive_more()
if self.lifespan is Lifespan.IDLE: if self.stage is Stage.IDLE:
self.lifespan = Lifespan.REQUEST self.stage = Stage.REQUEST
else: else:
self.lifespan = Lifespan.HANDLER self.stage = Stage.HANDLER
raise PayloadTooLarge("Payload Too Large") raise PayloadTooLarge("Payload Too Large")
self.protocol._total_request_size = pos + 4 self.protocol._total_request_size = pos + 4
try: try:
reqline, *headers = buf[:pos].decode().split("\r\n") reqline, *headers = buf[:pos].decode().split("\r\n")
method, url, protocol = reqline.split(" ") method, self.url, protocol = reqline.split(" ")
if protocol not in ("HTTP/1.0", "HTTP/1.1"): if protocol not in ("HTTP/1.0", "HTTP/1.1"):
raise Exception raise Exception
self.head_only = method.upper() == "HEAD" self.head_only = method.upper() == "HEAD"
@ -72,12 +73,12 @@ class Http:
for name, value in (h.split(":", 1) for h in headers) for name, value in (h.split(":", 1) for h in headers)
) )
except: except:
self.lifespan = Lifespan.HANDLER self.stage = Stage.HANDLER
raise InvalidUsage("Bad Request") raise InvalidUsage("Bad Request")
# Prepare a request object from the header received # Prepare a request object from the header received
request = self.protocol.request_class( request = self.protocol.request_class(
url_bytes=url.encode(), url_bytes=self.url.encode(),
headers=headers, headers=headers,
version=protocol[-3:], version=protocol[-3:],
method=method, method=method,
@ -86,8 +87,6 @@ class Http:
) )
request.stream = self request.stream = self
self.protocol.state["requests_count"] += 1 self.protocol.state["requests_count"] += 1
self.protocol.url = url
self.protocol.request = request
self.keep_alive = ( self.keep_alive = (
protocol == "HTTP/1.1" protocol == "HTTP/1.1"
or headers.get("connection", "").lower() == "keep-alive" or headers.get("connection", "").lower() == "keep-alive"
@ -98,7 +97,7 @@ class Http:
) )
self.request_chunked = False self.request_chunked = False
self.request_bytes_left = 0 self.request_bytes_left = 0
self.lifespan = Lifespan.HANDLER self.stage = Stage.HANDLER
if body: if body:
expect = headers.get("expect") expect = headers.get("expect")
if expect: if expect:
@ -121,14 +120,14 @@ class Http:
except Exception: except Exception:
logger.exception("Uncaught from app/handler") logger.exception("Uncaught from app/handler")
await self.write_error(ServerError("Internal Server Error")) await self.write_error(ServerError("Internal Server Error"))
if self.lifespan is Lifespan.IDLE: if self.stage is Stage.IDLE:
continue continue
if self.lifespan is Lifespan.HANDLER: if self.stage is Stage.HANDLER:
await self.respond(HTTPResponse(status=204)).send(end_stream=True) await self.respond(HTTPResponse(status=204)).send(end_stream=True)
# Finish sending a response (if no error) # Finish sending a response (if no error)
elif self.lifespan is Lifespan.RESPONSE: elif self.stage is Stage.RESPONSE:
await self.send(end_stream=True) await self.send(end_stream=True)
# Consume any remaining request body # Consume any remaining request body
@ -139,10 +138,10 @@ class Http:
while await self.read(): while await self.read():
pass pass
self.lifespan = Lifespan.IDLE self.stage = Stage.IDLE
async def write_error(self, e): async def write_error(self, e):
if self.lifespan is Lifespan.HANDLER: if self.stage is Stage.HANDLER:
try: try:
response = HTTPResponse(f"{e}", e.status_code, content_type="text/plain") response = HTTPResponse(f"{e}", e.status_code, content_type="text/plain")
await self.respond(response).send(end_stream=True) await self.respond(response).send(end_stream=True)
@ -210,8 +209,8 @@ class Http:
Nothing is sent until the first send() call on the returned object, and Nothing is sent until the first send() call on the returned object, and
calling this function multiple times will just alter the response to be calling this function multiple times will just alter the response to be
given.""" given."""
if self.lifespan is not Lifespan.HANDLER: if self.stage is not Stage.HANDLER:
self.lifespan = Lifespan.FAILED self.stage = Stage.FAILED
raise RuntimeError("Response already started") raise RuntimeError("Response already started")
if not isinstance(response.status, int) or response.status < 200: if not isinstance(response.status, int) or response.status < 200:
raise RuntimeError(f"Invalid response status {response.status!r}") raise RuntimeError(f"Invalid response status {response.status!r}")
@ -240,7 +239,7 @@ class Http:
size = len(data) if data is not None else 0 size = len(data) if data is not None else 0
# Headers not yet sent? # Headers not yet sent?
if self.lifespan is Lifespan.HANDLER: if self.stage is Stage.HANDLER:
if self.response.body: if self.response.body:
data = self.response.body + data if data else self.response.body data = self.response.body + data if data else self.response.body
size = len(data) size = len(data)
@ -297,14 +296,14 @@ class Http:
if status != 417: if status != 417:
ret = HTTP_CONTINUE + ret ret = HTTP_CONTINUE + ret
# Send response # Send response
self.lifespan = Lifespan.IDLE if end_stream else Lifespan.RESPONSE self.stage = Stage.IDLE if end_stream else Stage.RESPONSE
return ret return ret
# HEAD request: don't send body # HEAD request: don't send body
if self.head_only: if self.head_only:
return None return None
if self.lifespan is not Lifespan.RESPONSE: if self.stage is not Stage.RESPONSE:
if size: if size:
raise RuntimeError("Cannot send data to a closed stream") raise RuntimeError("Cannot send data to a closed stream")
return return
@ -313,7 +312,7 @@ class Http:
if self.response_bytes_left is True: if self.response_bytes_left is True:
if end_stream: if end_stream:
self.response_bytes_left = None self.response_bytes_left = None
self.lifespan = Lifespan.IDLE self.stage = Stage.IDLE
if size: if size:
return b"%x\r\n%b\r\n0\r\n\r\n" % (size, data) return b"%x\r\n%b\r\n0\r\n\r\n" % (size, data)
return b"0\r\n\r\n" return b"0\r\n\r\n"
@ -325,5 +324,5 @@ class Http:
if self.response_bytes_left <= 0: if self.response_bytes_left <= 0:
if self.response_bytes_left < 0: if self.response_bytes_left < 0:
raise ServerError("Response was bigger than content-length") raise ServerError("Response was bigger than content-length")
self.lifespan = Lifespan.IDLE self.stage = Stage.IDLE
return data if size else None return data if size else None

View File

@ -47,19 +47,6 @@ class RequestParameters(dict):
return super().get(name, default) return super().get(name, default)
class StreamBuffer:
def __init__(self, protocol):
self.read = protocol.stream_body
self.respond = protocol.respond
async def __aiter__(self):
while True:
data = await self.read()
if not data:
return
yield data
class Request: class Request:
"""Properties of an HTTP request such as URL, headers, etc.""" """Properties of an HTTP request such as URL, headers, etc."""

View File

@ -23,7 +23,7 @@ from sanic.exceptions import (
ServerError, ServerError,
ServiceUnavailable, ServiceUnavailable,
) )
from sanic.http import Http, Lifespan from sanic.http import Http, Stage
from sanic.log import access_logger, logger from sanic.log import access_logger, logger
from sanic.request import Request from sanic.request import Request
@ -71,7 +71,6 @@ class HttpProtocol(asyncio.Protocol):
# connection management # connection management
"_total_request_size", "_total_request_size",
"_last_response_time", "_last_response_time",
"keep_alive",
"state", "state",
"url", "url",
"_debug", "_debug",
@ -82,6 +81,7 @@ class HttpProtocol(asyncio.Protocol):
"_task", "_task",
"_http", "_http",
"_exception", "_exception",
"recv_buffer",
) )
def __init__( def __init__(
@ -126,7 +126,6 @@ class HttpProtocol(asyncio.Protocol):
self.request_max_size = request_max_size self.request_max_size = request_max_size
self.request_class = request_class or Request self.request_class = request_class or Request
self._total_request_size = 0 self._total_request_size = 0
self.keep_alive = keep_alive
self.state = state if state else {} self.state = state if state else {}
if "requests_count" not in self.state: if "requests_count" not in self.state:
self.state["requests_count"] = 0 self.state["requests_count"] = 0
@ -136,30 +135,40 @@ class HttpProtocol(asyncio.Protocol):
self._can_write.set() self._can_write.set()
self._exception = None self._exception = None
# -------------------------------------------- # async def connection_task(self):
# Connection """Run a HTTP connection.
# -------------------------------------------- #
def connection_made(self, transport): Timeouts and some additional error handling occur here, while most of
self.connections.add(self) everything else happens in class Http or in code called from there.
self.transport = transport """
#self.check_timeouts() try:
self._http = Http(self) self._http = Http(self)
self._task = self.loop.create_task(self.connection_task()) self._time = current_time()
self._time = current_time() self.check_timeouts()
self.check_timeouts() try:
await self._http.http1()
def connection_lost(self, exc): except asyncio.CancelledError:
self.connections.discard(self) await self._http.write_error(
if self._task: self._exception
self._task.cancel() or ServiceUnavailable("Request handler cancelled")
)
except SanicException as e:
await self._http.write_error(e)
except BaseException as e:
logger.exception(
f"Uncaught exception while handling URL {self._http.url}"
)
except asyncio.CancelledError:
pass
except:
logger.exception("protocol.connection_task uncaught")
finally:
self._http = None
self._task = None self._task = None
try:
def pause_writing(self): self.close()
self._can_write.clear() except:
logger.exception("Closing failed")
def resume_writing(self):
self._can_write.set()
async def receive_more(self): async def receive_more(self):
"""Wait until more data is received into self._buffer.""" """Wait until more data is received into self._buffer."""
@ -167,51 +176,18 @@ class HttpProtocol(asyncio.Protocol):
self._data_received.clear() self._data_received.clear()
await self._data_received.wait() await self._data_received.wait()
# -------------------------------------------- #
# Parsing
# -------------------------------------------- #
def data_received(self, data):
self._time = current_time()
if not data:
return self.close()
self._http.recv_buffer += data
if len(self._http.recv_buffer) > self.request_max_size:
self.transport.pause_reading()
if self._data_received:
self._data_received.set()
async def connection_task(self):
try:
await self._http.http1()
except asyncio.CancelledError:
await self._http.write_error(
self._exception
or ServiceUnavailable("Request handler cancelled")
)
except SanicException as e:
await self._http.write_error(e)
except BaseException as e:
logger.exception(f"Uncaught exception while handling URL {url}")
finally:
try:
self.close()
except:
logger.exception("Closing failed")
def check_timeouts(self): def check_timeouts(self):
"""Runs itself once a second to enforce any expired timeouts.""" """Runs itself once a second to enforce any expired timeouts."""
if not self._task: if not self._task:
return return
duration = current_time() - self._time duration = current_time() - self._time
lifespan = self._http.lifespan stage = self._http.stage
if lifespan == Lifespan.IDLE and duration > self.keep_alive_timeout: if stage is Stage.IDLE and duration > self.keep_alive_timeout:
logger.debug("KeepAlive Timeout. Closing connection.") logger.debug("KeepAlive Timeout. Closing connection.")
elif lifespan == Lifespan.REQUEST and duration > self.request_timeout: elif stage is Stage.REQUEST and duration > self.request_timeout:
self._exception = RequestTimeout("Request Timeout") self._exception = RequestTimeout("Request Timeout")
elif ( elif (
lifespan.value > Lifespan.REQUEST.value stage in (Stage.REQUEST, Stage.FAILED)
and duration > self.response_timeout and duration > self.response_timeout
): ):
self._exception = ServiceUnavailable("Response Timeout") self._exception = ServiceUnavailable("Response Timeout")
@ -220,20 +196,18 @@ class HttpProtocol(asyncio.Protocol):
return return
self._task.cancel() self._task.cancel()
async def drain(self): async def send(self, data):
"""Writes data with backpressure control."""
await self._can_write.wait() await self._can_write.wait()
async def push_data(self, data):
self._time = current_time()
await self.drain()
self.transport.write(data) self.transport.write(data)
self._time = current_time()
def close_if_idle(self): def close_if_idle(self):
"""Close the connection if a request is not being sent or received """Close the connection if a request is not being sent or received
:return: boolean - True if closed, false if staying open :return: boolean - True if closed, false if staying open
""" """
if self._http.lifespan == Lifespan.IDLE: if self._http.stage is Stage.IDLE:
self.close() self.close()
return True return True
return False return False
@ -242,16 +216,56 @@ class HttpProtocol(asyncio.Protocol):
""" """
Force close the connection. Force close the connection.
""" """
if self.transport is not None: # Cause a call to connection_lost where further cleanup occurs
try: if self.transport:
if self._task: self.transport.close()
self._task.cancel() self.transport = None
self._task = None
self.transport.close()
self.resume_writing()
finally:
self.transport = None
# -------------------------------------------- #
# Only asyncio.Protocol callbacks below this
# -------------------------------------------- #
def connection_made(self, transport):
try:
# TODO: Benchmark to find suitable write buffer limits
transport.set_write_buffer_limits(low=16384, high=65536)
self.connections.add(self)
self.transport = transport
self._task = self.loop.create_task(self.connection_task())
self.recv_buffer = bytearray()
except:
logger.exception("protocol.connect_made")
def connection_lost(self, exc):
try:
self.connections.discard(self)
self.resume_writing()
if self._task:
self._task.cancel()
except:
logger.exception("protocol.connection_lost")
def pause_writing(self):
self._can_write.clear()
def resume_writing(self):
self._can_write.set()
def data_received(self, data):
try:
self._time = current_time()
if not data:
return self.close()
self.recv_buffer += data
# Buffer up to request max size
if len(self.recv_buffer) > self.request_max_size:
self.transport.pause_reading()
if self._data_received:
self._data_received.set()
except:
logger.exception("protocol.data_received")
def trigger_events(events, loop): def trigger_events(events, loop):
"""Trigger event callbacks (functions or async) """Trigger event callbacks (functions or async)