Allow a custom Request class to be passed in to Sonic

Allowing a custom Request class to be defined would enable either a different Request class or a subclass of Request to be used, providing more flexibility.
This commit is contained in:
aryeh 2017-04-09 13:29:21 -04:00 committed by GitHub
parent 52ff2e0e63
commit 2ef8120073
2 changed files with 10 additions and 5 deletions

View File

@ -26,7 +26,7 @@ from sanic.websocket import WebSocketProtocol, ConnectionClosed
class Sanic: class Sanic:
def __init__(self, name=None, router=None, error_handler=None, def __init__(self, name=None, router=None, error_handler=None,
load_env=True): load_env=True, request_class=None):
# Only set up a default log handler if the # Only set up a default log handler if the
# end-user application didn't set anything up. # end-user application didn't set anything up.
if not logging.root.handlers and log.level == logging.NOTSET: if not logging.root.handlers and log.level == logging.NOTSET:
@ -44,6 +44,7 @@ class Sanic:
self.name = name self.name = name
self.router = router or Router() self.router = router or Router()
self.request_class = request_class
self.error_handler = error_handler or ErrorHandler() self.error_handler = error_handler or ErrorHandler()
self.config = Config(load_env=load_env) self.config = Config(load_env=load_env)
self.request_middleware = deque() self.request_middleware = deque()
@ -668,6 +669,7 @@ class Sanic:
server_settings = { server_settings = {
'protocol': protocol, 'protocol': protocol,
'request_class': self.request_class,
'host': host, 'host': host,
'port': port, 'port': port,
'sock': sock, 'sock': sock,

View File

@ -63,13 +63,13 @@ class HttpProtocol(asyncio.Protocol):
# request params # request params
'parser', 'request', 'url', 'headers', 'parser', 'request', 'url', 'headers',
# request config # request config
'request_handler', 'request_timeout', 'request_max_size', 'request_handler', 'request_timeout', 'request_max_size', 'request_class',
# connection management # connection management
'_total_request_size', '_timeout_handler', '_last_communication_time') '_total_request_size', '_timeout_handler', '_last_communication_time')
def __init__(self, *, loop, request_handler, error_handler, def __init__(self, *, loop, request_handler, error_handler,
signal=Signal(), connections=set(), request_timeout=60, signal=Signal(), connections=set(), request_timeout=60,
request_max_size=None): request_max_size=None, request_class=None):
self.loop = loop self.loop = loop
self.transport = None self.transport = None
self.request = None self.request = None
@ -82,6 +82,7 @@ class HttpProtocol(asyncio.Protocol):
self.error_handler = error_handler self.error_handler = error_handler
self.request_timeout = request_timeout self.request_timeout = request_timeout
self.request_max_size = request_max_size self.request_max_size = request_max_size
self.request_class = request_class or Request
self._total_request_size = 0 self._total_request_size = 0
self._timeout_handler = None self._timeout_handler = None
self._last_request_time = None self._last_request_time = None
@ -151,7 +152,7 @@ class HttpProtocol(asyncio.Protocol):
self.headers.append((name.decode().casefold(), value.decode())) self.headers.append((name.decode().casefold(), value.decode()))
def on_headers_complete(self): def on_headers_complete(self):
self.request = Request( self.request = self.request_class(
url_bytes=self.url, url_bytes=self.url,
headers=CIDict(self.headers), headers=CIDict(self.headers),
version=self.parser.get_http_version(), version=self.parser.get_http_version(),
@ -320,7 +321,7 @@ def serve(host, port, request_handler, error_handler, before_start=None,
request_timeout=60, ssl=None, sock=None, request_max_size=None, request_timeout=60, ssl=None, sock=None, request_max_size=None,
reuse_port=False, loop=None, protocol=HttpProtocol, backlog=100, reuse_port=False, loop=None, protocol=HttpProtocol, backlog=100,
register_sys_signals=True, run_async=False, connections=None, register_sys_signals=True, run_async=False, connections=None,
signal=Signal()): signal=Signal(), request_class=None):
"""Start asynchronous HTTP Server on an individual process. """Start asynchronous HTTP Server on an individual process.
:param host: Address to host on :param host: Address to host on
@ -345,6 +346,7 @@ def serve(host, port, request_handler, error_handler, before_start=None,
:param reuse_port: `True` for multiple workers :param reuse_port: `True` for multiple workers
:param loop: asyncio compatible event loop :param loop: asyncio compatible event loop
:param protocol: subclass of asyncio protocol class :param protocol: subclass of asyncio protocol class
:param request_class: Request class to use
:return: Nothing :return: Nothing
""" """
if not run_async: if not run_async:
@ -366,6 +368,7 @@ def serve(host, port, request_handler, error_handler, before_start=None,
error_handler=error_handler, error_handler=error_handler,
request_timeout=request_timeout, request_timeout=request_timeout,
request_max_size=request_max_size, request_max_size=request_max_size,
request_class=request_class,
) )
server_coroutine = loop.create_server( server_coroutine = loop.create_server(