Merged with master

This commit is contained in:
Channel Cat
2016-11-19 18:21:44 -08:00
11 changed files with 183 additions and 27 deletions

View File

@@ -109,8 +109,9 @@ class Blueprint:
# Detect which way this was called, @middleware or @middleware('AT')
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
middleware = args[0]
args = []
return register_middleware(args[0])
return register_middleware(middleware)
else:
return register_middleware

View File

@@ -8,6 +8,12 @@ from ujson import loads as json_loads
from .log import log
DEFAULT_HTTP_CONTENT_TYPE = "application/octet-stream"
# HTTP/1.1: https://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.2.1
# > If the media type remains unknown, the recipient SHOULD treat it
# > as type "application/octet-stream"
class RequestParameters(dict):
"""
Hosts a dict with lists as values where get returns the first
@@ -68,14 +74,13 @@ class Request(dict):
@property
def form(self):
if self.parsed_form is None:
self.parsed_form = {}
self.parsed_files = {}
content_type, parameters = parse_header(
self.headers.get('Content-Type', ''))
self.parsed_form = RequestParameters()
self.parsed_files = RequestParameters()
content_type = self.headers.get(
'Content-Type', DEFAULT_HTTP_CONTENT_TYPE)
content_type, parameters = parse_header(content_type)
try:
is_url_encoded = (
content_type == 'application/x-www-form-urlencoded')
if content_type is None or is_url_encoded:
if content_type == 'application/x-www-form-urlencoded':
self.parsed_form = RequestParameters(
parse_qs(self.body.decode('utf-8')))
elif content_type == 'multipart/form-data':
@@ -86,7 +91,6 @@ class Request(dict):
except Exception as e:
log.exception(e)
pass
return self.parsed_form
@property
@@ -128,10 +132,10 @@ def parse_multipart_form(body, boundary):
Parses a request body and returns fields and files
:param body: Bytes request body
:param boundary: Bytes multipart boundary
:return: fields (dict), files (dict)
:return: fields (RequestParameters), files (RequestParameters)
"""
files = {}
fields = {}
files = RequestParameters()
fields = RequestParameters()
form_parts = body.split(boundary)
for form_part in form_parts[1:-1]:
@@ -162,9 +166,16 @@ def parse_multipart_form(body, boundary):
post_data = form_part[line_index:-4]
if file_name or file_type:
files[field_name] = File(
type=file_type, name=file_name, body=post_data)
file = File(type=file_type, name=file_name, body=post_data)
if field_name in files:
files[field_name].append(file)
else:
files[field_name] = [file]
else:
fields[field_name] = post_data.decode('utf-8')
value = post_data.decode('utf-8')
if field_name in fields:
fields[field_name].append(value)
else:
fields[field_name] = [value]
return fields, files

View File

@@ -1,7 +1,7 @@
from asyncio import get_event_loop
from collections import deque
from functools import partial
from inspect import isawaitable
from inspect import isawaitable, stack, getmodulename
from multiprocessing import Process, Event
from signal import signal, SIGTERM, SIGINT
from time import sleep
@@ -18,7 +18,10 @@ from .exceptions import ServerError
class Sanic:
def __init__(self, name, router=None, error_handler=None):
def __init__(self, name=None, router=None, error_handler=None):
if name is None:
frame_records = stack()[1]
name = getmodulename(frame_records[1])
self.name = name
self.router = router or Router()
self.error_handler = error_handler or Handler(self)

View File

@@ -1,7 +1,9 @@
import asyncio
from functools import partial
from inspect import isawaitable
from signal import SIGINT, SIGTERM
from time import time
from aiohttp import CIMultiDict
import httptools
try:
@@ -17,6 +19,9 @@ class Signal:
stopped = False
current_time = None
class HttpProtocol(asyncio.Protocol):
__slots__ = (
# event loop, connection
@@ -26,7 +31,7 @@ class HttpProtocol(asyncio.Protocol):
# request config
'request_handler', 'request_timeout', 'request_max_size',
# connection management
'_total_request_size', '_timeout_handler')
'_total_request_size', '_timeout_handler', '_last_communication_time')
def __init__(self, *, loop, request_handler, signal=Signal(),
connections={}, request_timeout=60,
@@ -44,6 +49,7 @@ class HttpProtocol(asyncio.Protocol):
self.request_max_size = request_max_size
self._total_request_size = 0
self._timeout_handler = None
self._last_request_time = None
# -------------------------------------------- #
# Connection
@@ -54,6 +60,7 @@ class HttpProtocol(asyncio.Protocol):
self._timeout_handler = self.loop.call_later(
self.request_timeout, self.connection_timeout)
self.transport = transport
self._last_request_time = current_time
def connection_lost(self, exc):
del self.connections[self]
@@ -61,7 +68,14 @@ class HttpProtocol(asyncio.Protocol):
self.cleanup()
def connection_timeout(self):
self.bail_out("Request timed out, connection closed")
# Check if
time_elapsed = current_time - self._last_request_time
if time_elapsed < self.request_timeout:
time_left = self.request_timeout - time_elapsed
self._timeout_handler = \
self.loop.call_later(time_left, self.connection_timeout)
else:
self.bail_out("Request timed out, connection closed")
# -------------------------------------------- #
# Parsing
@@ -100,9 +114,13 @@ class HttpProtocol(asyncio.Protocol):
self.headers.append((name.decode(), value.decode('utf-8')))
def on_headers_complete(self):
remote_addr = self.transport.get_extra_info('peername')
if remote_addr:
self.headers.append(('Remote-Addr', '%s:%s' % remote_addr))
self.request = Request(
url_bytes=self.url,
headers=dict(self.headers),
headers=CIMultiDict(self.headers),
version=self.parser.get_http_version(),
method=self.parser.get_method().decode()
)
@@ -131,13 +149,15 @@ class HttpProtocol(asyncio.Protocol):
if not keep_alive:
self.transport.close()
else:
# Record that we received data
self._last_request_time = current_time
self.cleanup()
except Exception as e:
self.bail_out(
"Writing request failed, connection closed {}".format(e))
def bail_out(self, message):
log.error(message)
log.debug(message)
self.transport.close()
def cleanup(self):
@@ -158,6 +178,18 @@ class HttpProtocol(asyncio.Protocol):
return False
def update_current_time(loop):
"""
Caches the current time, since it is needed
at the end of every keep-alive request to update the request timeout time
:param loop:
:return:
"""
global current_time
current_time = time()
loop.call_later(1, partial(update_current_time, loop))
def trigger_events(events, loop):
"""
:param events: one or more sync or async functions to execute
@@ -212,6 +244,10 @@ def serve(host, port, request_handler, before_start=None, after_start=None,
request_max_size=request_max_size,
), host, port, reuse_port=reuse_port, sock=sock)
# Instead of pulling time at the end of every request,
# pull it once per minute
loop.call_soon(partial(update_current_time, loop))
try:
http_server = loop.run_until_complete(server_coroutine)
except Exception: