Revert check for websocket protocol to use hasattr

This commit is contained in:
Adam Hopkins 2020-09-29 00:24:00 +03:00
parent 33aa4daac8
commit 13094e02bc
2 changed files with 31 additions and 79 deletions

View File

@ -6,7 +6,6 @@ import socket
import stat import stat
import sys import sys
import traceback import traceback
from collections import deque from collections import deque
from functools import partial from functools import partial
from inspect import isawaitable from inspect import isawaitable
@ -14,7 +13,7 @@ from ipaddress import ip_address
from signal import SIG_IGN, SIGINT, SIGTERM, Signals from signal import SIG_IGN, SIGINT, SIGTERM, Signals
from signal import signal as signal_func from signal import signal as signal_func
from time import time from time import time
from typing import Type from typing import Dict, Type, Union
from httptools import HttpRequestParser # type: ignore from httptools import HttpRequestParser # type: ignore
from httptools.parser.errors import HttpParserError # type: ignore from httptools.parser.errors import HttpParserError # type: ignore
@ -33,7 +32,6 @@ from sanic.log import access_logger, logger
from sanic.request import EXPECT_HEADER, Request, StreamBuffer from sanic.request import EXPECT_HEADER, Request, StreamBuffer
from sanic.response import HTTPResponse from sanic.response import HTTPResponse
try: try:
import uvloop # type: ignore import uvloop # type: ignore
@ -160,9 +158,7 @@ class HttpProtocol(asyncio.Protocol):
self.request_handler = self.app.handle_request self.request_handler = self.app.handle_request
self.error_handler = self.app.error_handler self.error_handler = self.app.error_handler
self.request_timeout = self.app.config.REQUEST_TIMEOUT self.request_timeout = self.app.config.REQUEST_TIMEOUT
self.request_buffer_queue_size = ( self.request_buffer_queue_size = self.app.config.REQUEST_BUFFER_QUEUE_SIZE
self.app.config.REQUEST_BUFFER_QUEUE_SIZE
)
self.response_timeout = self.app.config.RESPONSE_TIMEOUT self.response_timeout = self.app.config.RESPONSE_TIMEOUT
self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT
self.request_max_size = self.app.config.REQUEST_MAX_SIZE self.request_max_size = self.app.config.REQUEST_MAX_SIZE
@ -335,9 +331,7 @@ class HttpProtocol(asyncio.Protocol):
value = value.decode() value = value.decode()
except UnicodeDecodeError: except UnicodeDecodeError:
value = value.decode("latin_1") value = value.decode("latin_1")
self.headers.append( self.headers.append((self._header_fragment.decode().casefold(), value))
(self._header_fragment.decode().casefold(), value)
)
self._header_fragment = b"" self._header_fragment = b""
@ -361,13 +355,9 @@ class HttpProtocol(asyncio.Protocol):
self.expect_handler() self.expect_handler()
if self.is_request_stream: if self.is_request_stream:
self._is_stream_handler = self.app.router.is_stream_handler( self._is_stream_handler = self.app.router.is_stream_handler(self.request)
self.request
)
if self._is_stream_handler: if self._is_stream_handler:
self.request.stream = StreamBuffer( self.request.stream = StreamBuffer(self.request_buffer_queue_size)
self.request_buffer_queue_size
)
self.execute_request_handler() self.execute_request_handler()
def expect_handler(self): def expect_handler(self):
@ -379,9 +369,7 @@ class HttpProtocol(asyncio.Protocol):
if expect.lower() == "100-continue": if expect.lower() == "100-continue":
self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n") self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
else: else:
self.write_error( self.write_error(HeaderExpectationFailed(f"Unknown Expect: {expect}"))
HeaderExpectationFailed(f"Unknown Expect: {expect}")
)
def on_body(self, body): def on_body(self, body):
if self.is_request_stream and self._is_stream_handler: if self.is_request_stream and self._is_stream_handler:
@ -390,13 +378,8 @@ class HttpProtocol(asyncio.Protocol):
# 3.7. so we should not create more than one task putting into the # 3.7. so we should not create more than one task putting into the
# queue simultaneously. # queue simultaneously.
self._body_chunks.append(body) self._body_chunks.append(body)
if ( if not self._request_stream_task or self._request_stream_task.done():
not self._request_stream_task self._request_stream_task = self.loop.create_task(self.stream_append())
or self._request_stream_task.done()
):
self._request_stream_task = self.loop.create_task(
self.stream_append()
)
else: else:
self.request.body_push(body) self.request.body_push(body)
@ -433,13 +416,8 @@ class HttpProtocol(asyncio.Protocol):
self._request_timeout_handler = None self._request_timeout_handler = None
if self.is_request_stream and self._is_stream_handler: if self.is_request_stream and self._is_stream_handler:
self._body_chunks.append(None) self._body_chunks.append(None)
if ( if not self._request_stream_task or self._request_stream_task.done():
not self._request_stream_task self._request_stream_task = self.loop.create_task(self.stream_append())
or self._request_stream_task.done()
):
self._request_stream_task = self.loop.create_task(
self.stream_append()
)
return return
self.request.body_finish() self.request.body_finish()
self.execute_request_handler() self.execute_request_handler()
@ -521,8 +499,7 @@ class HttpProtocol(asyncio.Protocol):
except RuntimeError: except RuntimeError:
if self.app.debug: if self.app.debug:
logger.error( logger.error(
"Connection lost before response written @ %s", "Connection lost before response written @ %s", self.request.ip,
self.request.ip,
) )
keep_alive = False keep_alive = False
except Exception as e: except Exception as e:
@ -572,8 +549,7 @@ class HttpProtocol(asyncio.Protocol):
except RuntimeError: except RuntimeError:
if self.app.debug: if self.app.debug:
logger.error( logger.error(
"Connection lost before response written @ %s", "Connection lost before response written @ %s", self.request.ip,
self.request.ip,
) )
keep_alive = False keep_alive = False
except Exception as e: except Exception as e:
@ -608,8 +584,7 @@ class HttpProtocol(asyncio.Protocol):
) )
except Exception as e: except Exception as e:
self.bail_out( self.bail_out(
f"Writing error failed, connection closed {e!r}", f"Writing error failed, connection closed {e!r}", from_error=True,
from_error=True,
) )
finally: finally:
if self.parser and ( if self.parser and (
@ -712,13 +687,7 @@ class AsyncioServer:
) )
def __init__( def __init__(
self, self, loop, serve_coro, connections, after_start, before_stop, after_stop,
loop,
serve_coro,
connections,
after_start,
before_stop,
after_stop,
): ):
# Note, Sanic already called "before_server_start" events # Note, Sanic already called "before_server_start" events
# before this helper was even created. So we don't need it here. # before this helper was even created. So we don't need it here.
@ -857,9 +826,7 @@ def serve(
unix=unix, unix=unix,
**protocol_kwargs, **protocol_kwargs,
) )
asyncio_server_kwargs = ( asyncio_server_kwargs = asyncio_server_kwargs if asyncio_server_kwargs else {}
asyncio_server_kwargs if asyncio_server_kwargs else {}
)
# UNIX sockets are always bound by us (to preserve semantics between modes) # UNIX sockets are always bound by us (to preserve semantics between modes)
if unix: if unix:
sock = bind_unix_socket(unix, backlog=backlog) sock = bind_unix_socket(unix, backlog=backlog)
@ -954,8 +921,8 @@ def serve(
def _build_protocol_kwargs( def _build_protocol_kwargs(
protocol: Type[HttpProtocol], config: Config protocol: Type[HttpProtocol], config: Config
) -> dict: ) -> Dict[str, Union[int, float]]:
if (dir(protocol).__contains__("websocket_handshake")): if hasattr(protocol, "websocket_handshake"):
return { return {
"websocket_max_size": config.WEBSOCKET_MAX_SIZE, "websocket_max_size": config.WEBSOCKET_MAX_SIZE,
"websocket_max_queue": config.WEBSOCKET_MAX_QUEUE, "websocket_max_queue": config.WEBSOCKET_MAX_QUEUE,
@ -977,9 +944,7 @@ def bind_socket(host: str, port: int, *, backlog=100) -> socket.socket:
try: # IP address: family must be specified for IPv6 at least try: # IP address: family must be specified for IPv6 at least
ip = ip_address(host) ip = ip_address(host)
host = str(ip) host = str(ip)
sock = socket.socket( sock = socket.socket(socket.AF_INET6 if ip.version == 6 else socket.AF_INET)
socket.AF_INET6 if ip.version == 6 else socket.AF_INET
)
except ValueError: # Hostname, may become AF_INET or AF_INET6 except ValueError: # Hostname, may become AF_INET or AF_INET6
sock = socket.socket() sock = socket.socket()
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

View File

@ -1,9 +1,8 @@
import asyncio import asyncio
import logging import logging
import sys import sys
from unittest.mock import patch
from inspect import isawaitable from inspect import isawaitable
from unittest.mock import patch
import pytest import pytest
@ -31,9 +30,7 @@ def test_app_loop_running(app):
assert response.text == "pass" assert response.text == "pass"
@pytest.mark.skipif( @pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher")
sys.version_info < (3, 7), reason="requires python3.7 or higher"
)
def test_create_asyncio_server(app): def test_create_asyncio_server(app):
if not uvloop_installed(): if not uvloop_installed():
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -43,9 +40,7 @@ def test_create_asyncio_server(app):
assert srv.is_serving() is True assert srv.is_serving() is True
@pytest.mark.skipif( @pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher")
sys.version_info < (3, 7), reason="requires python3.7 or higher"
)
def test_asyncio_server_no_start_serving(app): def test_asyncio_server_no_start_serving(app):
if not uvloop_installed(): if not uvloop_installed():
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -58,9 +53,7 @@ def test_asyncio_server_no_start_serving(app):
assert srv.is_serving() is False assert srv.is_serving() is False
@pytest.mark.skipif( @pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher")
sys.version_info < (3, 7), reason="requires python3.7 or higher"
)
def test_asyncio_server_start_serving(app): def test_asyncio_server_start_serving(app):
if not uvloop_installed(): if not uvloop_installed():
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -170,12 +163,12 @@ def test_app_websocket_parameters(websocket_protocol_mock, app):
websocket_protocol_call_args = websocket_protocol_mock.call_args websocket_protocol_call_args = websocket_protocol_mock.call_args
ws_kwargs = websocket_protocol_call_args[1] ws_kwargs = websocket_protocol_call_args[1]
assert ws_kwargs["max_size"] == app.config.WEBSOCKET_MAX_SIZE assert ws_kwargs["websocket_max_size"] == app.config.WEBSOCKET_MAX_SIZE
assert ws_kwargs["max_queue"] == app.config.WEBSOCKET_MAX_QUEUE assert ws_kwargs["websocket_max_queue"] == app.config.WEBSOCKET_MAX_QUEUE
assert ws_kwargs["read_limit"] == app.config.WEBSOCKET_READ_LIMIT assert ws_kwargs["websocket_read_limit"] == app.config.WEBSOCKET_READ_LIMIT
assert ws_kwargs["write_limit"] == app.config.WEBSOCKET_WRITE_LIMIT assert ws_kwargs["websocket_write_limit"] == app.config.WEBSOCKET_WRITE_LIMIT
assert ws_kwargs["ping_timeout"] == app.config.WEBSOCKET_PING_TIMEOUT assert ws_kwargs["websocket_ping_timeout"] == app.config.WEBSOCKET_PING_TIMEOUT
assert ws_kwargs["ping_interval"] == app.config.WEBSOCKET_PING_INTERVAL assert ws_kwargs["websocket_ping_interval"] == app.config.WEBSOCKET_PING_INTERVAL
def test_handle_request_with_nested_exception(app, monkeypatch): def test_handle_request_with_nested_exception(app, monkeypatch):
@ -186,9 +179,7 @@ def test_handle_request_with_nested_exception(app, monkeypatch):
def mock_error_handler_response(*args, **kwargs): def mock_error_handler_response(*args, **kwargs):
raise Exception(err_msg) raise Exception(err_msg)
monkeypatch.setattr( monkeypatch.setattr(app.error_handler, "response", mock_error_handler_response)
app.error_handler, "response", mock_error_handler_response
)
@app.get("/") @app.get("/")
def handler(request): def handler(request):
@ -207,9 +198,7 @@ def test_handle_request_with_nested_exception_debug(app, monkeypatch):
def mock_error_handler_response(*args, **kwargs): def mock_error_handler_response(*args, **kwargs):
raise Exception(err_msg) raise Exception(err_msg)
monkeypatch.setattr( monkeypatch.setattr(app.error_handler, "response", mock_error_handler_response)
app.error_handler, "response", mock_error_handler_response
)
@app.get("/") @app.get("/")
def handler(request): def handler(request):
@ -228,9 +217,7 @@ def test_handle_request_with_nested_sanic_exception(app, monkeypatch, caplog):
def mock_error_handler_response(*args, **kwargs): def mock_error_handler_response(*args, **kwargs):
raise SanicException("Mock SanicException") raise SanicException("Mock SanicException")
monkeypatch.setattr( monkeypatch.setattr(app.error_handler, "response", mock_error_handler_response)
app.error_handler, "response", mock_error_handler_response
)
@app.get("/") @app.get("/")
def handler(request): def handler(request):