Prepare initial websocket support

This commit is contained in:
Adam Hopkins
2019-05-22 01:42:19 +03:00
parent 8a56da84e6
commit 7b8e3624b8
10 changed files with 207 additions and 995 deletions

View File

@@ -8,7 +8,6 @@ from asyncio import CancelledError, Protocol, ensure_future, get_event_loop
from collections import defaultdict, deque
from functools import partial
from inspect import getmodulename, isawaitable, signature, stack
from multidict import CIMultiDict
from socket import socket
from ssl import Purpose, SSLContext, create_default_context
from traceback import format_exc
@@ -24,11 +23,10 @@ from sanic.exceptions import SanicException, ServerError, URLBuildError
from sanic.handlers import ErrorHandler
from sanic.log import LOGGING_CONFIG_DEFAULTS, error_logger, logger
from sanic.response import HTTPResponse, StreamingHTTPResponse
from sanic.request import Request
from sanic.router import Router
from sanic.server import HttpProtocol, Signal, serve, serve_multiple
from sanic.static import register as static_register
from sanic.testing import SanicTestClient, SanicASGITestClient
from sanic.testing import SanicASGITestClient, SanicTestClient
from sanic.views import CompositionView
from sanic.websocket import ConnectionClosed, WebSocketProtocol
@@ -56,6 +54,7 @@ class Sanic:
logging.config.dictConfig(log_config or LOGGING_CONFIG_DEFAULTS)
self.name = name
self.asgi = True
self.router = router or Router()
self.request_class = request_class
self.error_handler = error_handler or ErrorHandler()
@@ -468,13 +467,23 @@ class Sanic:
getattr(handler, "__blueprintname__", "")
+ handler.__name__
)
try:
protocol = request.transport.get_protocol()
except AttributeError:
# On Python3.5 the Transport classes in asyncio do not
# have a get_protocol() method as in uvloop
protocol = request.transport._protocol
ws = await protocol.websocket_handshake(request, subprotocols)
pass
if self.asgi:
ws = request.transport.get_websocket_connection()
else:
try:
protocol = request.transport.get_protocol()
except AttributeError:
# On Python3.5 the Transport classes in asyncio do not
# have a get_protocol() method as in uvloop
protocol = request.transport._protocol
protocol.app = self
ws = await protocol.websocket_handshake(
request, subprotocols
)
# schedule the application handler
# its future is kept in self.websocket_tasks in case it
@@ -985,7 +994,13 @@ class Sanic:
if write_callback is None or isinstance(
response, StreamingHTTPResponse
):
await stream_callback(response)
if stream_callback:
await stream_callback(response)
else:
# Should only end here IF it is an ASGI websocket.
# TODO:
# - Add exception handling
pass
else:
write_callback(response)
@@ -1374,5 +1389,5 @@ class Sanic:
# -------------------------------------------------------------------- #
async def __call__(self, scope, receive, send):
asgi_app = ASGIApp(self, scope, receive, send)
asgi_app = await ASGIApp.create(self, scope, receive, send)
await asgi_app()

View File

@@ -1,24 +1,50 @@
from sanic.request import Request
from multidict import CIMultiDict
from sanic.response import StreamingHTTPResponse
from typing import Any, Awaitable, Callable, MutableMapping, Union
from multidict import CIMultiDict
from sanic.request import Request
from sanic.response import HTTPResponse, StreamingHTTPResponse
from sanic.websocket import WebSocketConnection
ASGIScope = MutableMapping[str, Any]
ASGIMessage = MutableMapping[str, Any]
ASGISend = Callable[[ASGIMessage], Awaitable[None]]
ASGIReceive = Callable[[], Awaitable[ASGIMessage]]
class MockTransport:
def __init__(self, scope):
def __init__(self, scope: ASGIScope) -> None:
self.scope = scope
def get_extra_info(self, info):
def get_extra_info(self, info: str) -> Union[str, bool]:
if info == "peername":
return self.scope.get("server")
elif info == "sslcontext":
return self.scope.get("scheme") in ["https", "wss"]
def get_websocket_connection(self) -> WebSocketConnection:
return self._websocket_connection
def create_websocket_connection(
self,
send: ASGISend,
receive: ASGIReceive,
) -> WebSocketConnection:
self._websocket_connection = WebSocketConnection(send, receive)
return self._websocket_connection
class ASGIApp:
def __init__(self, sanic_app, scope, receive, send):
self.sanic_app = sanic_app
self.receive = receive
self.send = send
def __init__(self) -> None:
self.ws = None
@classmethod
async def create(cls, sanic_app, scope: ASGIScope, receive: ASGIReceive, send: ASGISend) -> "ASGIApp":
instance = cls()
instance.sanic_app = sanic_app
instance.receive = receive
instance.send = send
url_bytes = scope.get("root_path", "") + scope["path"]
url_bytes = url_bytes.encode("latin-1")
url_bytes += scope["query_string"]
@@ -28,18 +54,30 @@ class ASGIApp:
for key, value in scope.get("headers", [])
]
)
version = scope["http_version"]
method = scope["method"]
self.request = Request(
url_bytes,
headers,
version,
method,
MockTransport(scope),
sanic_app,
transport = MockTransport(scope)
if scope["type"] == "http":
version = scope["http_version"]
method = scope["method"]
elif scope["type"] == "websocket":
version = "1.1"
method = "GET"
instance.ws = transport.create_websocket_connection(send, receive)
await instance.ws.accept()
else:
pass
# TODO:
# - close connection
instance.request = Request(
url_bytes, headers, version, method, transport, sanic_app
)
async def read_body(self):
return instance
async def read_body(self) -> bytes:
"""
Read and return the entire body from an incoming ASGI message.
"""
@@ -53,15 +91,16 @@ class ASGIApp:
return body
async def __call__(self):
async def __call__(self) -> None:
"""
Handle the incoming request.
"""
self.request.body = await self.read_body()
handler = self.sanic_app.handle_request
await handler(self.request, None, self.stream_callback)
callback = None if self.ws else self.stream_callback
await handler(self.request, None, callback)
async def stream_callback(self, response):
async def stream_callback(self, response: HTTPResponse) -> None:
"""
Write the response.
"""

View File

@@ -708,6 +708,8 @@ def serve(
if debug:
loop.set_debug(debug)
app.asgi = False
connections = connections if connections is not None else set()
server = partial(
protocol,

View File

@@ -1,16 +1,21 @@
import typing
import types
import asyncio
from json import JSONDecodeError
from socket import socket
from urllib.parse import unquote, urljoin, urlsplit
from urllib.parse import unquote, urlsplit
import httpcore
import requests_async as requests
import typing
import websockets
from sanic.asgi import ASGIApp
from sanic.exceptions import MethodNotSupported
from sanic.log import logger
from sanic.response import text
HOST = "127.0.0.1"
PORT = 42101
@@ -314,7 +319,7 @@ class TestASGIApp(ASGIApp):
async def app_call_with_return(self, scope, receive, send):
asgi_app = TestASGIApp(self, scope, receive, send)
asgi_app = await TestASGIApp.create(self, scope, receive, send)
return await asgi_app()

View File

@@ -1,3 +1,5 @@
from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union
from httptools import HttpParserUpgrade
from websockets import ConnectionClosed # noqa
from websockets import InvalidHandshake, WebSocketCommonProtocol, handshake
@@ -6,6 +8,9 @@ from sanic.exceptions import InvalidUsage
from sanic.server import HttpProtocol
ASIMessage = MutableMapping[str, Any]
class WebSocketProtocol(HttpProtocol):
def __init__(
self,
@@ -19,6 +24,7 @@ class WebSocketProtocol(HttpProtocol):
):
super().__init__(*args, **kwargs)
self.websocket = None
self.app = None
self.websocket_timeout = websocket_timeout
self.websocket_max_size = websocket_max_size
self.websocket_max_queue = websocket_max_queue
@@ -103,3 +109,46 @@ class WebSocketProtocol(HttpProtocol):
self.websocket.connection_made(request.transport)
self.websocket.connection_open()
return self.websocket
class WebSocketConnection:
# TODO
# - Implement ping/pong
def __init__(
self,
send: Callable[[ASIMessage], Awaitable[None]],
receive: Callable[[], Awaitable[ASIMessage]],
) -> None:
self._send = send
self._receive = receive
async def send(self, data: Union[str, bytes], *args, **kwargs) -> None:
message = {"type": "websocket.send"}
try:
data.decode()
except AttributeError:
message.update({"text": str(data)})
else:
message.update({"bytes": data})
await self._send(message)
async def recv(self, *args, **kwargs) -> Optional[str]:
message = await self._receive()
if message["type"] == "websocket.receive":
return message["text"]
elif message["type"] == "websocket.disconnect":
pass
# await self._send({
# "type": "websocket.close"
# })
async def accept(self) -> None:
await self._send({"type": "websocket.accept", "subprotocol": ""})
async def close(self) -> None:
pass