Merge pull request #128 from channelcat/keep-alive-timeout-fix

Keep alive timeout fix
This commit is contained in:
Eli Uriegas 2016-11-05 12:11:41 -06:00 committed by GitHub
commit 1271c7d958

View File

@ -1,6 +1,8 @@
import asyncio import asyncio
from functools import partial
from inspect import isawaitable from inspect import isawaitable
from signal import SIGINT, SIGTERM from signal import SIGINT, SIGTERM
from time import time
import httptools import httptools
@ -17,6 +19,9 @@ class Signal:
stopped = False stopped = False
current_time = None
class HttpProtocol(asyncio.Protocol): class HttpProtocol(asyncio.Protocol):
__slots__ = ( __slots__ = (
# event loop, connection # event loop, connection
@ -26,7 +31,7 @@ class HttpProtocol(asyncio.Protocol):
# request config # request config
'request_handler', 'request_timeout', 'request_max_size', 'request_handler', 'request_timeout', 'request_max_size',
# connection management # connection management
'_total_request_size', '_timeout_handler') '_total_request_size', '_timeout_handler', '_last_communication_time')
def __init__(self, *, loop, request_handler, signal=Signal(), def __init__(self, *, loop, request_handler, signal=Signal(),
connections={}, request_timeout=60, connections={}, request_timeout=60,
@ -44,6 +49,7 @@ class HttpProtocol(asyncio.Protocol):
self.request_max_size = request_max_size self.request_max_size = request_max_size
self._total_request_size = 0 self._total_request_size = 0
self._timeout_handler = None self._timeout_handler = None
self._last_request_time = None
# -------------------------------------------- # # -------------------------------------------- #
# Connection # Connection
@ -54,6 +60,7 @@ class HttpProtocol(asyncio.Protocol):
self._timeout_handler = self.loop.call_later( self._timeout_handler = self.loop.call_later(
self.request_timeout, self.connection_timeout) self.request_timeout, self.connection_timeout)
self.transport = transport self.transport = transport
self._last_request_time = current_time
def connection_lost(self, exc): def connection_lost(self, exc):
del self.connections[self] del self.connections[self]
@ -61,6 +68,13 @@ class HttpProtocol(asyncio.Protocol):
self.cleanup() self.cleanup()
def connection_timeout(self): def connection_timeout(self):
# Check if
time_elapsed = current_time - self._last_request_time
if time_elapsed < self.request_timeout:
time_left = self.request_timeout - time_elapsed
self._timeout_handler = \
self.loop.call_later(time_left, self.connection_timeout)
else:
self.bail_out("Request timed out, connection closed") self.bail_out("Request timed out, connection closed")
# -------------------------------------------- # # -------------------------------------------- #
@ -131,13 +145,15 @@ class HttpProtocol(asyncio.Protocol):
if not keep_alive: if not keep_alive:
self.transport.close() self.transport.close()
else: else:
# Record that we received data
self._last_request_time = current_time
self.cleanup() self.cleanup()
except Exception as e: except Exception as e:
self.bail_out( self.bail_out(
"Writing request failed, connection closed {}".format(e)) "Writing request failed, connection closed {}".format(e))
def bail_out(self, message): def bail_out(self, message):
log.error(message) log.debug(message)
self.transport.close() self.transport.close()
def cleanup(self): def cleanup(self):
@ -158,6 +174,18 @@ class HttpProtocol(asyncio.Protocol):
return False return False
def update_current_time(loop):
"""
Caches the current time, since it is needed
at the end of every keep-alive request to update the request timeout time
:param loop:
:return:
"""
global current_time
current_time = time()
loop.call_later(1, partial(update_current_time, loop))
def trigger_events(events, loop): def trigger_events(events, loop):
""" """
:param events: one or more sync or async functions to execute :param events: one or more sync or async functions to execute
@ -212,6 +240,10 @@ def serve(host, port, request_handler, before_start=None, after_start=None,
request_max_size=request_max_size, request_max_size=request_max_size,
), host, port, reuse_port=reuse_port, sock=sock) ), host, port, reuse_port=reuse_port, sock=sock)
# Instead of pulling time at the end of every request,
# pull it once per minute
loop.call_soon(partial(update_current_time, loop))
try: try:
http_server = loop.run_until_complete(server_coroutine) http_server = loop.run_until_complete(server_coroutine)
except Exception: except Exception: