Revert check for websocket protocol to use hasattr

This commit is contained in:
Adam Hopkins
2020-09-29 00:24:00 +03:00
parent 33aa4daac8
commit 13094e02bc
2 changed files with 31 additions and 79 deletions

View File

@@ -6,7 +6,6 @@ import socket
import stat
import sys
import traceback
from collections import deque
from functools import partial
from inspect import isawaitable
@@ -14,7 +13,7 @@ from ipaddress import ip_address
from signal import SIG_IGN, SIGINT, SIGTERM, Signals
from signal import signal as signal_func
from time import time
from typing import Type
from typing import Dict, Type, Union
from httptools import HttpRequestParser # type: ignore
from httptools.parser.errors import HttpParserError # type: ignore
@@ -33,7 +32,6 @@ from sanic.log import access_logger, logger
from sanic.request import EXPECT_HEADER, Request, StreamBuffer
from sanic.response import HTTPResponse
try:
import uvloop # type: ignore
@@ -160,9 +158,7 @@ class HttpProtocol(asyncio.Protocol):
self.request_handler = self.app.handle_request
self.error_handler = self.app.error_handler
self.request_timeout = self.app.config.REQUEST_TIMEOUT
self.request_buffer_queue_size = (
self.app.config.REQUEST_BUFFER_QUEUE_SIZE
)
self.request_buffer_queue_size = self.app.config.REQUEST_BUFFER_QUEUE_SIZE
self.response_timeout = self.app.config.RESPONSE_TIMEOUT
self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT
self.request_max_size = self.app.config.REQUEST_MAX_SIZE
@@ -335,9 +331,7 @@ class HttpProtocol(asyncio.Protocol):
value = value.decode()
except UnicodeDecodeError:
value = value.decode("latin_1")
self.headers.append(
(self._header_fragment.decode().casefold(), value)
)
self.headers.append((self._header_fragment.decode().casefold(), value))
self._header_fragment = b""
@@ -361,13 +355,9 @@ class HttpProtocol(asyncio.Protocol):
self.expect_handler()
if self.is_request_stream:
self._is_stream_handler = self.app.router.is_stream_handler(
self.request
)
self._is_stream_handler = self.app.router.is_stream_handler(self.request)
if self._is_stream_handler:
self.request.stream = StreamBuffer(
self.request_buffer_queue_size
)
self.request.stream = StreamBuffer(self.request_buffer_queue_size)
self.execute_request_handler()
def expect_handler(self):
@@ -379,9 +369,7 @@ class HttpProtocol(asyncio.Protocol):
if expect.lower() == "100-continue":
self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
else:
self.write_error(
HeaderExpectationFailed(f"Unknown Expect: {expect}")
)
self.write_error(HeaderExpectationFailed(f"Unknown Expect: {expect}"))
def on_body(self, body):
if self.is_request_stream and self._is_stream_handler:
@@ -390,13 +378,8 @@ class HttpProtocol(asyncio.Protocol):
# 3.7. so we should not create more than one task putting into the
# queue simultaneously.
self._body_chunks.append(body)
if (
not self._request_stream_task
or self._request_stream_task.done()
):
self._request_stream_task = self.loop.create_task(
self.stream_append()
)
if not self._request_stream_task or self._request_stream_task.done():
self._request_stream_task = self.loop.create_task(self.stream_append())
else:
self.request.body_push(body)
@@ -433,13 +416,8 @@ class HttpProtocol(asyncio.Protocol):
self._request_timeout_handler = None
if self.is_request_stream and self._is_stream_handler:
self._body_chunks.append(None)
if (
not self._request_stream_task
or self._request_stream_task.done()
):
self._request_stream_task = self.loop.create_task(
self.stream_append()
)
if not self._request_stream_task or self._request_stream_task.done():
self._request_stream_task = self.loop.create_task(self.stream_append())
return
self.request.body_finish()
self.execute_request_handler()
@@ -521,8 +499,7 @@ class HttpProtocol(asyncio.Protocol):
except RuntimeError:
if self.app.debug:
logger.error(
"Connection lost before response written @ %s",
self.request.ip,
"Connection lost before response written @ %s", self.request.ip,
)
keep_alive = False
except Exception as e:
@@ -572,8 +549,7 @@ class HttpProtocol(asyncio.Protocol):
except RuntimeError:
if self.app.debug:
logger.error(
"Connection lost before response written @ %s",
self.request.ip,
"Connection lost before response written @ %s", self.request.ip,
)
keep_alive = False
except Exception as e:
@@ -608,8 +584,7 @@ class HttpProtocol(asyncio.Protocol):
)
except Exception as e:
self.bail_out(
f"Writing error failed, connection closed {e!r}",
from_error=True,
f"Writing error failed, connection closed {e!r}", from_error=True,
)
finally:
if self.parser and (
@@ -712,13 +687,7 @@ class AsyncioServer:
)
def __init__(
self,
loop,
serve_coro,
connections,
after_start,
before_stop,
after_stop,
self, loop, serve_coro, connections, after_start, before_stop, after_stop,
):
# Note, Sanic already called "before_server_start" events
# before this helper was even created. So we don't need it here.
@@ -857,9 +826,7 @@ def serve(
unix=unix,
**protocol_kwargs,
)
asyncio_server_kwargs = (
asyncio_server_kwargs if asyncio_server_kwargs else {}
)
asyncio_server_kwargs = asyncio_server_kwargs if asyncio_server_kwargs else {}
# UNIX sockets are always bound by us (to preserve semantics between modes)
if unix:
sock = bind_unix_socket(unix, backlog=backlog)
@@ -954,8 +921,8 @@ def serve(
def _build_protocol_kwargs(
protocol: Type[HttpProtocol], config: Config
) -> dict:
if (dir(protocol).__contains__("websocket_handshake")):
) -> Dict[str, Union[int, float]]:
if hasattr(protocol, "websocket_handshake"):
return {
"websocket_max_size": config.WEBSOCKET_MAX_SIZE,
"websocket_max_queue": config.WEBSOCKET_MAX_QUEUE,
@@ -977,9 +944,7 @@ def bind_socket(host: str, port: int, *, backlog=100) -> socket.socket:
try: # IP address: family must be specified for IPv6 at least
ip = ip_address(host)
host = str(ip)
sock = socket.socket(
socket.AF_INET6 if ip.version == 6 else socket.AF_INET
)
sock = socket.socket(socket.AF_INET6 if ip.version == 6 else socket.AF_INET)
except ValueError: # Hostname, may become AF_INET or AF_INET6
sock = socket.socket()
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)