Added tests and middleware, and improved documentation
This commit is contained in:
4
sanic/middleware.py
Normal file
4
sanic/middleware.py
Normal file
@@ -0,0 +1,4 @@
|
||||
class Middleware:
|
||||
def __init__(self, process_request=None, process_response=None):
|
||||
self.process_request = process_request
|
||||
self.process_response = process_response
|
||||
@@ -1,7 +1,11 @@
|
||||
from cgi import parse_header
|
||||
from collections import namedtuple
|
||||
from httptools import parse_url
|
||||
from urllib.parse import parse_qs
|
||||
from ujson import loads as json_loads
|
||||
|
||||
from .log import log
|
||||
|
||||
class RequestParameters(dict):
|
||||
"""
|
||||
Hosts a dict with lists as values where get returns the first
|
||||
@@ -20,7 +24,7 @@ class Request:
|
||||
__slots__ = (
|
||||
'url', 'headers', 'version', 'method',
|
||||
'query_string', 'body',
|
||||
'parsed_json', 'parsed_args', 'parsed_form',
|
||||
'parsed_json', 'parsed_args', 'parsed_form', 'parsed_files',
|
||||
)
|
||||
|
||||
def __init__(self, url_bytes, headers, version, method):
|
||||
@@ -36,6 +40,7 @@ class Request:
|
||||
self.body = None
|
||||
self.parsed_json = None
|
||||
self.parsed_form = None
|
||||
self.parsed_files = None
|
||||
self.parsed_args = None
|
||||
|
||||
@property
|
||||
@@ -50,17 +55,30 @@ class Request:
|
||||
|
||||
@property
|
||||
def form(self):
|
||||
if not self.parsed_form:
|
||||
content_type = self.headers.get('Content-Type')
|
||||
if self.parsed_form is None:
|
||||
self.parsed_form = {}
|
||||
self.parsed_files = {}
|
||||
content_type, parameters = parse_header(self.headers.get('Content-Type'))
|
||||
try:
|
||||
# TODO: form-data
|
||||
if content_type is None or content_type == 'application/x-www-form-urlencoded':
|
||||
self.parsed_form = RequestParameters(parse_qs(self.body.decode('utf-8')))
|
||||
except:
|
||||
elif content_type == 'multipart/form-data':
|
||||
# TODO: Stream this instead of reading to/from memory
|
||||
boundary = parameters['boundary'].encode('utf-8')
|
||||
self.parsed_form, self.parsed_files = parse_multipart_form(self.body, boundary)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
pass
|
||||
|
||||
return self.parsed_form
|
||||
|
||||
@property
|
||||
def files(self):
|
||||
if self.parsed_files is None:
|
||||
_ = self.form # compute form to get files
|
||||
|
||||
return self.parsed_files
|
||||
|
||||
@property
|
||||
def args(self):
|
||||
if self.parsed_args is None:
|
||||
@@ -70,3 +88,49 @@ class Request:
|
||||
self.parsed_args = {}
|
||||
|
||||
return self.parsed_args
|
||||
|
||||
File = namedtuple('File', ['type', 'body', 'name'])
|
||||
def parse_multipart_form(body, boundary):
|
||||
"""
|
||||
Parses a request body and returns fields and files
|
||||
:param body: Bytes request body
|
||||
:param boundary: Bytes multipart boundary
|
||||
:return: fields (dict), files (dict)
|
||||
"""
|
||||
files = {}
|
||||
fields = {}
|
||||
|
||||
form_parts = body.split(boundary)
|
||||
for form_part in form_parts[1:-1]:
|
||||
file_name = None
|
||||
file_type = None
|
||||
field_name = None
|
||||
line_index = 2
|
||||
line_end_index = 0
|
||||
while not line_end_index == -1:
|
||||
line_end_index = form_part.find(b'\r\n', line_index)
|
||||
form_line = form_part[line_index:line_end_index].decode('utf-8')
|
||||
line_index = line_end_index + 2
|
||||
|
||||
if not form_line:
|
||||
break
|
||||
|
||||
colon_index = form_line.index(':')
|
||||
form_header_field = form_line[0:colon_index]
|
||||
form_header_value, form_parameters = parse_header(form_line[colon_index+2:])
|
||||
|
||||
if form_header_field == 'Content-Disposition':
|
||||
if 'filename' in form_parameters:
|
||||
file_name = form_parameters['filename']
|
||||
field_name = form_parameters.get('name')
|
||||
elif form_header_field == 'Content-Type':
|
||||
file_type = form_header_value
|
||||
|
||||
|
||||
post_data = form_part[line_index:-4]
|
||||
if file_name or file_type:
|
||||
files[field_name] = File(type=file_type, name=file_name, body=post_data)
|
||||
else:
|
||||
fields[field_name] = post_data.decode('utf-8')
|
||||
|
||||
return fields, files
|
||||
@@ -53,8 +53,8 @@ class HTTPResponse:
|
||||
])
|
||||
|
||||
def json(body, status=200, headers=None):
|
||||
return HTTPResponse(ujson.dumps(body), headers=headers, status=status, content_type="application/json")
|
||||
return HTTPResponse(ujson.dumps(body), headers=headers, status=status, content_type="application/json; charset=utf-8")
|
||||
def text(body, status=200, headers=None):
|
||||
return HTTPResponse(body, status=status, headers=headers, content_type="text/plain")
|
||||
return HTTPResponse(body, status=status, headers=headers, content_type="text/plain; charset=utf-8")
|
||||
def html(body, status=200, headers=None):
|
||||
return HTTPResponse(body, status=status, headers=headers, content_type="text/html")
|
||||
return HTTPResponse(body, status=status, headers=headers, content_type="text/html; charset=utf-8")
|
||||
@@ -1,30 +1,32 @@
|
||||
import asyncio
|
||||
from inspect import isawaitable
|
||||
from traceback import format_exc
|
||||
from types import FunctionType
|
||||
|
||||
from .config import Config
|
||||
from .exceptions import Handler
|
||||
from .log import log, logging
|
||||
from .middleware import Middleware
|
||||
from .response import HTTPResponse
|
||||
from .router import Router
|
||||
from .server import serve
|
||||
from .exceptions import ServerError
|
||||
from inspect import isawaitable
|
||||
from traceback import format_exc
|
||||
|
||||
class Sanic:
|
||||
name = None
|
||||
debug = None
|
||||
router = None
|
||||
error_handler = None
|
||||
routes = []
|
||||
|
||||
def __init__(self, name, router=None, error_handler=None):
|
||||
self.name = name
|
||||
self.router = router or Router()
|
||||
self.router = router or Router()
|
||||
self.error_handler = error_handler or Handler(self)
|
||||
self.config = Config()
|
||||
self.request_middleware = []
|
||||
self.response_middleware = []
|
||||
|
||||
# -------------------------------------------------------------------- #
|
||||
# Decorators
|
||||
# Registration
|
||||
# -------------------------------------------------------------------- #
|
||||
|
||||
# Decorator
|
||||
def route(self, uri, methods=None):
|
||||
"""
|
||||
Decorates a function to be registered as a route
|
||||
@@ -38,6 +40,7 @@ class Sanic:
|
||||
|
||||
return response
|
||||
|
||||
# Decorator
|
||||
def exception(self, *exceptions):
|
||||
"""
|
||||
Decorates a function to be registered as a route
|
||||
@@ -52,6 +55,34 @@ class Sanic:
|
||||
|
||||
return response
|
||||
|
||||
# Decorator
|
||||
def middleware(self, *args, **kwargs):
|
||||
"""
|
||||
Decorates and registers middleware to be called before a request
|
||||
can either be called as @app.middleware or @app.middleware('request')
|
||||
"""
|
||||
middleware = None
|
||||
attach_to = 'request'
|
||||
def register_middleware(middleware):
|
||||
if attach_to == 'request':
|
||||
self.request_middleware.append(middleware)
|
||||
if attach_to == 'response':
|
||||
self.response_middleware.append(middleware)
|
||||
return middleware
|
||||
|
||||
# Detect which way this was called, @middleware or @middleware('AT')
|
||||
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
|
||||
return register_middleware(args[0])
|
||||
else:
|
||||
attach_to = args[0]
|
||||
log.info(attach_to)
|
||||
return register_middleware
|
||||
|
||||
if isinstance(middleware, FunctionType):
|
||||
middleware = Middleware(process_request=middleware)
|
||||
|
||||
return middleware
|
||||
|
||||
# -------------------------------------------------------------------- #
|
||||
# Request Handling
|
||||
# -------------------------------------------------------------------- #
|
||||
@@ -65,13 +96,35 @@ class Sanic:
|
||||
:return: Nothing
|
||||
"""
|
||||
try:
|
||||
handler, args, kwargs = self.router.get(request)
|
||||
if handler is None:
|
||||
raise ServerError("'None' was returned while requesting a handler from the router")
|
||||
# Middleware process_request
|
||||
response = None
|
||||
for middleware in self.request_middleware:
|
||||
response = middleware(request)
|
||||
if isawaitable(response):
|
||||
response = await response
|
||||
if response is not None:
|
||||
break
|
||||
|
||||
response = handler(request, *args, **kwargs)
|
||||
if isawaitable(response):
|
||||
response = await response
|
||||
# No middleware results
|
||||
if response is None:
|
||||
# Fetch handler from router
|
||||
handler, args, kwargs = self.router.get(request)
|
||||
if handler is None:
|
||||
raise ServerError("'None' was returned while requesting a handler from the router")
|
||||
|
||||
# Run response handler
|
||||
response = handler(request, *args, **kwargs)
|
||||
if isawaitable(response):
|
||||
response = await response
|
||||
|
||||
# Middleware process_response
|
||||
for middleware in self.response_middleware:
|
||||
_response = middleware(request, response)
|
||||
if isawaitable(_response):
|
||||
_response = await _response
|
||||
if _response is not None:
|
||||
response = _response
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
try:
|
||||
@@ -90,14 +143,14 @@ class Sanic:
|
||||
# Execution
|
||||
# -------------------------------------------------------------------- #
|
||||
|
||||
def run(self, host="127.0.0.1", port=8000, debug=False, before_start=None, before_stop=None):
|
||||
def run(self, host="127.0.0.1", port=8000, debug=False, after_start=None, before_stop=None):
|
||||
"""
|
||||
Runs the HTTP Server and listens until keyboard interrupt or term signal.
|
||||
On termination, drains connections before closing.
|
||||
:param host: Address to host on
|
||||
:param port: Port to host on
|
||||
:param debug: Enables debug output (slows server)
|
||||
:param before_start: Function to be executed after the event loop is created and before the server starts
|
||||
:param after_start: Function to be executed after the server starts listening
|
||||
:param before_stop: Function to be executed when a stop signal is received before it is respected
|
||||
:return: Nothing
|
||||
"""
|
||||
@@ -116,11 +169,17 @@ class Sanic:
|
||||
host=host,
|
||||
port=port,
|
||||
debug=debug,
|
||||
before_start=before_start,
|
||||
after_start=after_start,
|
||||
before_stop=before_stop,
|
||||
request_handler=self.handle_request,
|
||||
request_timeout=self.config.REQUEST_TIMEOUT,
|
||||
request_max_size=self.config.REQUEST_MAX_SIZE,
|
||||
)
|
||||
except:
|
||||
pass
|
||||
pass
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
This kills the Sanic
|
||||
"""
|
||||
asyncio.get_event_loop().stop()
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
from inspect import isawaitable
|
||||
from signal import SIGINT, SIGTERM
|
||||
|
||||
import httptools
|
||||
try:
|
||||
@@ -132,17 +133,13 @@ class HttpProtocol(asyncio.Protocol):
|
||||
return True
|
||||
return False
|
||||
|
||||
def serve(host, port, request_handler, before_start=None, before_stop=None, debug=False, request_timeout=60, request_max_size=None):
|
||||
def serve(host, port, request_handler, after_start=None, before_stop=None, debug=False, request_timeout=60, request_max_size=None):
|
||||
# Create Event Loop
|
||||
loop = async_loop.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.set_debug(debug)
|
||||
|
||||
# Run the on_start function if provided
|
||||
if before_start:
|
||||
result = before_start(loop)
|
||||
if isawaitable(result):
|
||||
loop.run_until_complete(result)
|
||||
# I don't think we take advantage of this
|
||||
# And it slows everything waaayyy down
|
||||
#loop.set_debug(debug)
|
||||
|
||||
connections = {}
|
||||
signal = Signal()
|
||||
@@ -156,10 +153,18 @@ def serve(host, port, request_handler, before_start=None, before_stop=None, debu
|
||||
), host, port)
|
||||
http_server = loop.run_until_complete(server_coroutine)
|
||||
|
||||
# Run the on_start function if provided
|
||||
if after_start:
|
||||
result = after_start(loop)
|
||||
if isawaitable(result):
|
||||
loop.run_until_complete(result)
|
||||
|
||||
# Register signals for graceful termination
|
||||
for _signal in (SIGINT, SIGTERM):
|
||||
loop.add_signal_handler(_signal, loop.stop)
|
||||
|
||||
try:
|
||||
loop.run_forever()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
log.info("Stop requested, draining connections...")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user