diff --git a/sanic/app.py b/sanic/app.py index ad4d7923..d2894ff2 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -26,7 +26,7 @@ from sanic.websocket import WebSocketProtocol, ConnectionClosed class Sanic: def __init__(self, name=None, router=None, error_handler=None, - load_env=True): + load_env=True, request_class=None): # Only set up a default log handler if the # end-user application didn't set anything up. if not logging.root.handlers and log.level == logging.NOTSET: @@ -44,6 +44,7 @@ class Sanic: self.name = name self.router = router or Router() + self.request_class = request_class self.error_handler = error_handler or ErrorHandler() self.config = Config(load_env=load_env) self.request_middleware = deque() @@ -668,6 +669,7 @@ class Sanic: server_settings = { 'protocol': protocol, + 'request_class': self.request_class, 'host': host, 'port': port, 'sock': sock, diff --git a/sanic/server.py b/sanic/server.py index 976503a6..04dffefd 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -63,13 +63,13 @@ class HttpProtocol(asyncio.Protocol): # request params 'parser', 'request', 'url', 'headers', # request config - 'request_handler', 'request_timeout', 'request_max_size', + 'request_handler', 'request_timeout', 'request_max_size', 'request_class', # connection management '_total_request_size', '_timeout_handler', '_last_communication_time') def __init__(self, *, loop, request_handler, error_handler, signal=Signal(), connections=set(), request_timeout=60, - request_max_size=None): + request_max_size=None, request_class=None): self.loop = loop self.transport = None self.request = None @@ -82,6 +82,7 @@ class HttpProtocol(asyncio.Protocol): self.error_handler = error_handler self.request_timeout = request_timeout self.request_max_size = request_max_size + self.request_class = request_class or Request self._total_request_size = 0 self._timeout_handler = None self._last_request_time = None @@ -151,7 +152,7 @@ class HttpProtocol(asyncio.Protocol): self.headers.append((name.decode().casefold(), value.decode())) def on_headers_complete(self): - self.request = Request( + self.request = self.request_class( url_bytes=self.url, headers=CIDict(self.headers), version=self.parser.get_http_version(), @@ -320,7 +321,7 @@ def serve(host, port, request_handler, error_handler, before_start=None, request_timeout=60, ssl=None, sock=None, request_max_size=None, reuse_port=False, loop=None, protocol=HttpProtocol, backlog=100, register_sys_signals=True, run_async=False, connections=None, - signal=Signal()): + signal=Signal(), request_class=None): """Start asynchronous HTTP Server on an individual process. :param host: Address to host on @@ -345,6 +346,7 @@ def serve(host, port, request_handler, error_handler, before_start=None, :param reuse_port: `True` for multiple workers :param loop: asyncio compatible event loop :param protocol: subclass of asyncio protocol class + :param request_class: Request class to use :return: Nothing """ if not run_async: @@ -366,6 +368,7 @@ def serve(host, port, request_handler, error_handler, before_start=None, error_handler=error_handler, request_timeout=request_timeout, request_max_size=request_max_size, + request_class=request_class, ) server_coroutine = loop.create_server(