Revert check for websocket protocol to use hasattr
This commit is contained in:
parent
33aa4daac8
commit
13094e02bc
@ -6,7 +6,6 @@ import socket
|
|||||||
import stat
|
import stat
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from inspect import isawaitable
|
from inspect import isawaitable
|
||||||
@ -14,7 +13,7 @@ from ipaddress import ip_address
|
|||||||
from signal import SIG_IGN, SIGINT, SIGTERM, Signals
|
from signal import SIG_IGN, SIGINT, SIGTERM, Signals
|
||||||
from signal import signal as signal_func
|
from signal import signal as signal_func
|
||||||
from time import time
|
from time import time
|
||||||
from typing import Type
|
from typing import Dict, Type, Union
|
||||||
|
|
||||||
from httptools import HttpRequestParser # type: ignore
|
from httptools import HttpRequestParser # type: ignore
|
||||||
from httptools.parser.errors import HttpParserError # type: ignore
|
from httptools.parser.errors import HttpParserError # type: ignore
|
||||||
@ -33,7 +32,6 @@ from sanic.log import access_logger, logger
|
|||||||
from sanic.request import EXPECT_HEADER, Request, StreamBuffer
|
from sanic.request import EXPECT_HEADER, Request, StreamBuffer
|
||||||
from sanic.response import HTTPResponse
|
from sanic.response import HTTPResponse
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import uvloop # type: ignore
|
import uvloop # type: ignore
|
||||||
|
|
||||||
@ -160,9 +158,7 @@ class HttpProtocol(asyncio.Protocol):
|
|||||||
self.request_handler = self.app.handle_request
|
self.request_handler = self.app.handle_request
|
||||||
self.error_handler = self.app.error_handler
|
self.error_handler = self.app.error_handler
|
||||||
self.request_timeout = self.app.config.REQUEST_TIMEOUT
|
self.request_timeout = self.app.config.REQUEST_TIMEOUT
|
||||||
self.request_buffer_queue_size = (
|
self.request_buffer_queue_size = self.app.config.REQUEST_BUFFER_QUEUE_SIZE
|
||||||
self.app.config.REQUEST_BUFFER_QUEUE_SIZE
|
|
||||||
)
|
|
||||||
self.response_timeout = self.app.config.RESPONSE_TIMEOUT
|
self.response_timeout = self.app.config.RESPONSE_TIMEOUT
|
||||||
self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT
|
self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT
|
||||||
self.request_max_size = self.app.config.REQUEST_MAX_SIZE
|
self.request_max_size = self.app.config.REQUEST_MAX_SIZE
|
||||||
@ -335,9 +331,7 @@ class HttpProtocol(asyncio.Protocol):
|
|||||||
value = value.decode()
|
value = value.decode()
|
||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
value = value.decode("latin_1")
|
value = value.decode("latin_1")
|
||||||
self.headers.append(
|
self.headers.append((self._header_fragment.decode().casefold(), value))
|
||||||
(self._header_fragment.decode().casefold(), value)
|
|
||||||
)
|
|
||||||
|
|
||||||
self._header_fragment = b""
|
self._header_fragment = b""
|
||||||
|
|
||||||
@ -361,13 +355,9 @@ class HttpProtocol(asyncio.Protocol):
|
|||||||
self.expect_handler()
|
self.expect_handler()
|
||||||
|
|
||||||
if self.is_request_stream:
|
if self.is_request_stream:
|
||||||
self._is_stream_handler = self.app.router.is_stream_handler(
|
self._is_stream_handler = self.app.router.is_stream_handler(self.request)
|
||||||
self.request
|
|
||||||
)
|
|
||||||
if self._is_stream_handler:
|
if self._is_stream_handler:
|
||||||
self.request.stream = StreamBuffer(
|
self.request.stream = StreamBuffer(self.request_buffer_queue_size)
|
||||||
self.request_buffer_queue_size
|
|
||||||
)
|
|
||||||
self.execute_request_handler()
|
self.execute_request_handler()
|
||||||
|
|
||||||
def expect_handler(self):
|
def expect_handler(self):
|
||||||
@ -379,9 +369,7 @@ class HttpProtocol(asyncio.Protocol):
|
|||||||
if expect.lower() == "100-continue":
|
if expect.lower() == "100-continue":
|
||||||
self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
|
self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
|
||||||
else:
|
else:
|
||||||
self.write_error(
|
self.write_error(HeaderExpectationFailed(f"Unknown Expect: {expect}"))
|
||||||
HeaderExpectationFailed(f"Unknown Expect: {expect}")
|
|
||||||
)
|
|
||||||
|
|
||||||
def on_body(self, body):
|
def on_body(self, body):
|
||||||
if self.is_request_stream and self._is_stream_handler:
|
if self.is_request_stream and self._is_stream_handler:
|
||||||
@ -390,13 +378,8 @@ class HttpProtocol(asyncio.Protocol):
|
|||||||
# 3.7. so we should not create more than one task putting into the
|
# 3.7. so we should not create more than one task putting into the
|
||||||
# queue simultaneously.
|
# queue simultaneously.
|
||||||
self._body_chunks.append(body)
|
self._body_chunks.append(body)
|
||||||
if (
|
if not self._request_stream_task or self._request_stream_task.done():
|
||||||
not self._request_stream_task
|
self._request_stream_task = self.loop.create_task(self.stream_append())
|
||||||
or self._request_stream_task.done()
|
|
||||||
):
|
|
||||||
self._request_stream_task = self.loop.create_task(
|
|
||||||
self.stream_append()
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.request.body_push(body)
|
self.request.body_push(body)
|
||||||
|
|
||||||
@ -433,13 +416,8 @@ class HttpProtocol(asyncio.Protocol):
|
|||||||
self._request_timeout_handler = None
|
self._request_timeout_handler = None
|
||||||
if self.is_request_stream and self._is_stream_handler:
|
if self.is_request_stream and self._is_stream_handler:
|
||||||
self._body_chunks.append(None)
|
self._body_chunks.append(None)
|
||||||
if (
|
if not self._request_stream_task or self._request_stream_task.done():
|
||||||
not self._request_stream_task
|
self._request_stream_task = self.loop.create_task(self.stream_append())
|
||||||
or self._request_stream_task.done()
|
|
||||||
):
|
|
||||||
self._request_stream_task = self.loop.create_task(
|
|
||||||
self.stream_append()
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
self.request.body_finish()
|
self.request.body_finish()
|
||||||
self.execute_request_handler()
|
self.execute_request_handler()
|
||||||
@ -521,8 +499,7 @@ class HttpProtocol(asyncio.Protocol):
|
|||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
if self.app.debug:
|
if self.app.debug:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Connection lost before response written @ %s",
|
"Connection lost before response written @ %s", self.request.ip,
|
||||||
self.request.ip,
|
|
||||||
)
|
)
|
||||||
keep_alive = False
|
keep_alive = False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -572,8 +549,7 @@ class HttpProtocol(asyncio.Protocol):
|
|||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
if self.app.debug:
|
if self.app.debug:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Connection lost before response written @ %s",
|
"Connection lost before response written @ %s", self.request.ip,
|
||||||
self.request.ip,
|
|
||||||
)
|
)
|
||||||
keep_alive = False
|
keep_alive = False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -608,8 +584,7 @@ class HttpProtocol(asyncio.Protocol):
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.bail_out(
|
self.bail_out(
|
||||||
f"Writing error failed, connection closed {e!r}",
|
f"Writing error failed, connection closed {e!r}", from_error=True,
|
||||||
from_error=True,
|
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
if self.parser and (
|
if self.parser and (
|
||||||
@ -712,13 +687,7 @@ class AsyncioServer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self, loop, serve_coro, connections, after_start, before_stop, after_stop,
|
||||||
loop,
|
|
||||||
serve_coro,
|
|
||||||
connections,
|
|
||||||
after_start,
|
|
||||||
before_stop,
|
|
||||||
after_stop,
|
|
||||||
):
|
):
|
||||||
# Note, Sanic already called "before_server_start" events
|
# Note, Sanic already called "before_server_start" events
|
||||||
# before this helper was even created. So we don't need it here.
|
# before this helper was even created. So we don't need it here.
|
||||||
@ -857,9 +826,7 @@ def serve(
|
|||||||
unix=unix,
|
unix=unix,
|
||||||
**protocol_kwargs,
|
**protocol_kwargs,
|
||||||
)
|
)
|
||||||
asyncio_server_kwargs = (
|
asyncio_server_kwargs = asyncio_server_kwargs if asyncio_server_kwargs else {}
|
||||||
asyncio_server_kwargs if asyncio_server_kwargs else {}
|
|
||||||
)
|
|
||||||
# UNIX sockets are always bound by us (to preserve semantics between modes)
|
# UNIX sockets are always bound by us (to preserve semantics between modes)
|
||||||
if unix:
|
if unix:
|
||||||
sock = bind_unix_socket(unix, backlog=backlog)
|
sock = bind_unix_socket(unix, backlog=backlog)
|
||||||
@ -954,8 +921,8 @@ def serve(
|
|||||||
|
|
||||||
def _build_protocol_kwargs(
|
def _build_protocol_kwargs(
|
||||||
protocol: Type[HttpProtocol], config: Config
|
protocol: Type[HttpProtocol], config: Config
|
||||||
) -> dict:
|
) -> Dict[str, Union[int, float]]:
|
||||||
if (dir(protocol).__contains__("websocket_handshake")):
|
if hasattr(protocol, "websocket_handshake"):
|
||||||
return {
|
return {
|
||||||
"websocket_max_size": config.WEBSOCKET_MAX_SIZE,
|
"websocket_max_size": config.WEBSOCKET_MAX_SIZE,
|
||||||
"websocket_max_queue": config.WEBSOCKET_MAX_QUEUE,
|
"websocket_max_queue": config.WEBSOCKET_MAX_QUEUE,
|
||||||
@ -977,9 +944,7 @@ def bind_socket(host: str, port: int, *, backlog=100) -> socket.socket:
|
|||||||
try: # IP address: family must be specified for IPv6 at least
|
try: # IP address: family must be specified for IPv6 at least
|
||||||
ip = ip_address(host)
|
ip = ip_address(host)
|
||||||
host = str(ip)
|
host = str(ip)
|
||||||
sock = socket.socket(
|
sock = socket.socket(socket.AF_INET6 if ip.version == 6 else socket.AF_INET)
|
||||||
socket.AF_INET6 if ip.version == 6 else socket.AF_INET
|
|
||||||
)
|
|
||||||
except ValueError: # Hostname, may become AF_INET or AF_INET6
|
except ValueError: # Hostname, may become AF_INET or AF_INET6
|
||||||
sock = socket.socket()
|
sock = socket.socket()
|
||||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
from inspect import isawaitable
|
from inspect import isawaitable
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -31,9 +30,7 @@ def test_app_loop_running(app):
|
|||||||
assert response.text == "pass"
|
assert response.text == "pass"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher")
|
||||||
sys.version_info < (3, 7), reason="requires python3.7 or higher"
|
|
||||||
)
|
|
||||||
def test_create_asyncio_server(app):
|
def test_create_asyncio_server(app):
|
||||||
if not uvloop_installed():
|
if not uvloop_installed():
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
@ -43,9 +40,7 @@ def test_create_asyncio_server(app):
|
|||||||
assert srv.is_serving() is True
|
assert srv.is_serving() is True
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher")
|
||||||
sys.version_info < (3, 7), reason="requires python3.7 or higher"
|
|
||||||
)
|
|
||||||
def test_asyncio_server_no_start_serving(app):
|
def test_asyncio_server_no_start_serving(app):
|
||||||
if not uvloop_installed():
|
if not uvloop_installed():
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
@ -58,9 +53,7 @@ def test_asyncio_server_no_start_serving(app):
|
|||||||
assert srv.is_serving() is False
|
assert srv.is_serving() is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher")
|
||||||
sys.version_info < (3, 7), reason="requires python3.7 or higher"
|
|
||||||
)
|
|
||||||
def test_asyncio_server_start_serving(app):
|
def test_asyncio_server_start_serving(app):
|
||||||
if not uvloop_installed():
|
if not uvloop_installed():
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
@ -170,12 +163,12 @@ def test_app_websocket_parameters(websocket_protocol_mock, app):
|
|||||||
|
|
||||||
websocket_protocol_call_args = websocket_protocol_mock.call_args
|
websocket_protocol_call_args = websocket_protocol_mock.call_args
|
||||||
ws_kwargs = websocket_protocol_call_args[1]
|
ws_kwargs = websocket_protocol_call_args[1]
|
||||||
assert ws_kwargs["max_size"] == app.config.WEBSOCKET_MAX_SIZE
|
assert ws_kwargs["websocket_max_size"] == app.config.WEBSOCKET_MAX_SIZE
|
||||||
assert ws_kwargs["max_queue"] == app.config.WEBSOCKET_MAX_QUEUE
|
assert ws_kwargs["websocket_max_queue"] == app.config.WEBSOCKET_MAX_QUEUE
|
||||||
assert ws_kwargs["read_limit"] == app.config.WEBSOCKET_READ_LIMIT
|
assert ws_kwargs["websocket_read_limit"] == app.config.WEBSOCKET_READ_LIMIT
|
||||||
assert ws_kwargs["write_limit"] == app.config.WEBSOCKET_WRITE_LIMIT
|
assert ws_kwargs["websocket_write_limit"] == app.config.WEBSOCKET_WRITE_LIMIT
|
||||||
assert ws_kwargs["ping_timeout"] == app.config.WEBSOCKET_PING_TIMEOUT
|
assert ws_kwargs["websocket_ping_timeout"] == app.config.WEBSOCKET_PING_TIMEOUT
|
||||||
assert ws_kwargs["ping_interval"] == app.config.WEBSOCKET_PING_INTERVAL
|
assert ws_kwargs["websocket_ping_interval"] == app.config.WEBSOCKET_PING_INTERVAL
|
||||||
|
|
||||||
|
|
||||||
def test_handle_request_with_nested_exception(app, monkeypatch):
|
def test_handle_request_with_nested_exception(app, monkeypatch):
|
||||||
@ -186,9 +179,7 @@ def test_handle_request_with_nested_exception(app, monkeypatch):
|
|||||||
def mock_error_handler_response(*args, **kwargs):
|
def mock_error_handler_response(*args, **kwargs):
|
||||||
raise Exception(err_msg)
|
raise Exception(err_msg)
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(app.error_handler, "response", mock_error_handler_response)
|
||||||
app.error_handler, "response", mock_error_handler_response
|
|
||||||
)
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
def handler(request):
|
def handler(request):
|
||||||
@ -207,9 +198,7 @@ def test_handle_request_with_nested_exception_debug(app, monkeypatch):
|
|||||||
def mock_error_handler_response(*args, **kwargs):
|
def mock_error_handler_response(*args, **kwargs):
|
||||||
raise Exception(err_msg)
|
raise Exception(err_msg)
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(app.error_handler, "response", mock_error_handler_response)
|
||||||
app.error_handler, "response", mock_error_handler_response
|
|
||||||
)
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
def handler(request):
|
def handler(request):
|
||||||
@ -228,9 +217,7 @@ def test_handle_request_with_nested_sanic_exception(app, monkeypatch, caplog):
|
|||||||
def mock_error_handler_response(*args, **kwargs):
|
def mock_error_handler_response(*args, **kwargs):
|
||||||
raise SanicException("Mock SanicException")
|
raise SanicException("Mock SanicException")
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(app.error_handler, "response", mock_error_handler_response)
|
||||||
app.error_handler, "response", mock_error_handler_response
|
|
||||||
)
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
def handler(request):
|
def handler(request):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user