Prepare initial websocket support
This commit is contained in:
39
sanic/app.py
39
sanic/app.py
@@ -8,7 +8,6 @@ from asyncio import CancelledError, Protocol, ensure_future, get_event_loop
|
||||
from collections import defaultdict, deque
|
||||
from functools import partial
|
||||
from inspect import getmodulename, isawaitable, signature, stack
|
||||
from multidict import CIMultiDict
|
||||
from socket import socket
|
||||
from ssl import Purpose, SSLContext, create_default_context
|
||||
from traceback import format_exc
|
||||
@@ -24,11 +23,10 @@ from sanic.exceptions import SanicException, ServerError, URLBuildError
|
||||
from sanic.handlers import ErrorHandler
|
||||
from sanic.log import LOGGING_CONFIG_DEFAULTS, error_logger, logger
|
||||
from sanic.response import HTTPResponse, StreamingHTTPResponse
|
||||
from sanic.request import Request
|
||||
from sanic.router import Router
|
||||
from sanic.server import HttpProtocol, Signal, serve, serve_multiple
|
||||
from sanic.static import register as static_register
|
||||
from sanic.testing import SanicTestClient, SanicASGITestClient
|
||||
from sanic.testing import SanicASGITestClient, SanicTestClient
|
||||
from sanic.views import CompositionView
|
||||
from sanic.websocket import ConnectionClosed, WebSocketProtocol
|
||||
|
||||
@@ -56,6 +54,7 @@ class Sanic:
|
||||
logging.config.dictConfig(log_config or LOGGING_CONFIG_DEFAULTS)
|
||||
|
||||
self.name = name
|
||||
self.asgi = True
|
||||
self.router = router or Router()
|
||||
self.request_class = request_class
|
||||
self.error_handler = error_handler or ErrorHandler()
|
||||
@@ -468,13 +467,23 @@ class Sanic:
|
||||
getattr(handler, "__blueprintname__", "")
|
||||
+ handler.__name__
|
||||
)
|
||||
try:
|
||||
protocol = request.transport.get_protocol()
|
||||
except AttributeError:
|
||||
# On Python3.5 the Transport classes in asyncio do not
|
||||
# have a get_protocol() method as in uvloop
|
||||
protocol = request.transport._protocol
|
||||
ws = await protocol.websocket_handshake(request, subprotocols)
|
||||
|
||||
pass
|
||||
|
||||
if self.asgi:
|
||||
ws = request.transport.get_websocket_connection()
|
||||
else:
|
||||
try:
|
||||
protocol = request.transport.get_protocol()
|
||||
except AttributeError:
|
||||
# On Python3.5 the Transport classes in asyncio do not
|
||||
# have a get_protocol() method as in uvloop
|
||||
protocol = request.transport._protocol
|
||||
protocol.app = self
|
||||
|
||||
ws = await protocol.websocket_handshake(
|
||||
request, subprotocols
|
||||
)
|
||||
|
||||
# schedule the application handler
|
||||
# its future is kept in self.websocket_tasks in case it
|
||||
@@ -985,7 +994,13 @@ class Sanic:
|
||||
if write_callback is None or isinstance(
|
||||
response, StreamingHTTPResponse
|
||||
):
|
||||
await stream_callback(response)
|
||||
if stream_callback:
|
||||
await stream_callback(response)
|
||||
else:
|
||||
# Should only end here IF it is an ASGI websocket.
|
||||
# TODO:
|
||||
# - Add exception handling
|
||||
pass
|
||||
else:
|
||||
write_callback(response)
|
||||
|
||||
@@ -1374,5 +1389,5 @@ class Sanic:
|
||||
# -------------------------------------------------------------------- #
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
asgi_app = ASGIApp(self, scope, receive, send)
|
||||
asgi_app = await ASGIApp.create(self, scope, receive, send)
|
||||
await asgi_app()
|
||||
|
||||
@@ -1,24 +1,50 @@
|
||||
from sanic.request import Request
|
||||
from multidict import CIMultiDict
|
||||
from sanic.response import StreamingHTTPResponse
|
||||
from typing import Any, Awaitable, Callable, MutableMapping, Union
|
||||
|
||||
from multidict import CIMultiDict
|
||||
|
||||
from sanic.request import Request
|
||||
from sanic.response import HTTPResponse, StreamingHTTPResponse
|
||||
from sanic.websocket import WebSocketConnection
|
||||
|
||||
|
||||
ASGIScope = MutableMapping[str, Any]
|
||||
ASGIMessage = MutableMapping[str, Any]
|
||||
ASGISend = Callable[[ASGIMessage], Awaitable[None]]
|
||||
ASGIReceive = Callable[[], Awaitable[ASGIMessage]]
|
||||
|
||||
class MockTransport:
|
||||
def __init__(self, scope):
|
||||
def __init__(self, scope: ASGIScope) -> None:
|
||||
self.scope = scope
|
||||
|
||||
def get_extra_info(self, info):
|
||||
def get_extra_info(self, info: str) -> Union[str, bool]:
|
||||
if info == "peername":
|
||||
return self.scope.get("server")
|
||||
elif info == "sslcontext":
|
||||
return self.scope.get("scheme") in ["https", "wss"]
|
||||
|
||||
def get_websocket_connection(self) -> WebSocketConnection:
|
||||
return self._websocket_connection
|
||||
|
||||
def create_websocket_connection(
|
||||
self,
|
||||
send: ASGISend,
|
||||
receive: ASGIReceive,
|
||||
) -> WebSocketConnection:
|
||||
self._websocket_connection = WebSocketConnection(send, receive)
|
||||
return self._websocket_connection
|
||||
|
||||
|
||||
class ASGIApp:
|
||||
def __init__(self, sanic_app, scope, receive, send):
|
||||
self.sanic_app = sanic_app
|
||||
self.receive = receive
|
||||
self.send = send
|
||||
def __init__(self) -> None:
|
||||
self.ws = None
|
||||
|
||||
@classmethod
|
||||
async def create(cls, sanic_app, scope: ASGIScope, receive: ASGIReceive, send: ASGISend) -> "ASGIApp":
|
||||
instance = cls()
|
||||
instance.sanic_app = sanic_app
|
||||
instance.receive = receive
|
||||
instance.send = send
|
||||
|
||||
url_bytes = scope.get("root_path", "") + scope["path"]
|
||||
url_bytes = url_bytes.encode("latin-1")
|
||||
url_bytes += scope["query_string"]
|
||||
@@ -28,18 +54,30 @@ class ASGIApp:
|
||||
for key, value in scope.get("headers", [])
|
||||
]
|
||||
)
|
||||
version = scope["http_version"]
|
||||
method = scope["method"]
|
||||
self.request = Request(
|
||||
url_bytes,
|
||||
headers,
|
||||
version,
|
||||
method,
|
||||
MockTransport(scope),
|
||||
sanic_app,
|
||||
|
||||
transport = MockTransport(scope)
|
||||
|
||||
if scope["type"] == "http":
|
||||
version = scope["http_version"]
|
||||
method = scope["method"]
|
||||
elif scope["type"] == "websocket":
|
||||
version = "1.1"
|
||||
method = "GET"
|
||||
|
||||
instance.ws = transport.create_websocket_connection(send, receive)
|
||||
await instance.ws.accept()
|
||||
else:
|
||||
pass
|
||||
# TODO:
|
||||
# - close connection
|
||||
|
||||
instance.request = Request(
|
||||
url_bytes, headers, version, method, transport, sanic_app
|
||||
)
|
||||
|
||||
async def read_body(self):
|
||||
return instance
|
||||
|
||||
async def read_body(self) -> bytes:
|
||||
"""
|
||||
Read and return the entire body from an incoming ASGI message.
|
||||
"""
|
||||
@@ -53,15 +91,16 @@ class ASGIApp:
|
||||
|
||||
return body
|
||||
|
||||
async def __call__(self):
|
||||
async def __call__(self) -> None:
|
||||
"""
|
||||
Handle the incoming request.
|
||||
"""
|
||||
self.request.body = await self.read_body()
|
||||
handler = self.sanic_app.handle_request
|
||||
await handler(self.request, None, self.stream_callback)
|
||||
callback = None if self.ws else self.stream_callback
|
||||
await handler(self.request, None, callback)
|
||||
|
||||
async def stream_callback(self, response):
|
||||
async def stream_callback(self, response: HTTPResponse) -> None:
|
||||
"""
|
||||
Write the response.
|
||||
"""
|
||||
|
||||
@@ -708,6 +708,8 @@ def serve(
|
||||
if debug:
|
||||
loop.set_debug(debug)
|
||||
|
||||
app.asgi = False
|
||||
|
||||
connections = connections if connections is not None else set()
|
||||
server = partial(
|
||||
protocol,
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
import typing
|
||||
import types
|
||||
import asyncio
|
||||
|
||||
from json import JSONDecodeError
|
||||
from socket import socket
|
||||
from urllib.parse import unquote, urljoin, urlsplit
|
||||
from urllib.parse import unquote, urlsplit
|
||||
|
||||
import httpcore
|
||||
import requests_async as requests
|
||||
import typing
|
||||
import websockets
|
||||
|
||||
from sanic.asgi import ASGIApp
|
||||
from sanic.exceptions import MethodNotSupported
|
||||
from sanic.log import logger
|
||||
from sanic.response import text
|
||||
|
||||
|
||||
HOST = "127.0.0.1"
|
||||
PORT = 42101
|
||||
|
||||
@@ -314,7 +319,7 @@ class TestASGIApp(ASGIApp):
|
||||
|
||||
|
||||
async def app_call_with_return(self, scope, receive, send):
|
||||
asgi_app = TestASGIApp(self, scope, receive, send)
|
||||
asgi_app = await TestASGIApp.create(self, scope, receive, send)
|
||||
return await asgi_app()
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union
|
||||
|
||||
from httptools import HttpParserUpgrade
|
||||
from websockets import ConnectionClosed # noqa
|
||||
from websockets import InvalidHandshake, WebSocketCommonProtocol, handshake
|
||||
@@ -6,6 +8,9 @@ from sanic.exceptions import InvalidUsage
|
||||
from sanic.server import HttpProtocol
|
||||
|
||||
|
||||
ASIMessage = MutableMapping[str, Any]
|
||||
|
||||
|
||||
class WebSocketProtocol(HttpProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -19,6 +24,7 @@ class WebSocketProtocol(HttpProtocol):
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.websocket = None
|
||||
self.app = None
|
||||
self.websocket_timeout = websocket_timeout
|
||||
self.websocket_max_size = websocket_max_size
|
||||
self.websocket_max_queue = websocket_max_queue
|
||||
@@ -103,3 +109,46 @@ class WebSocketProtocol(HttpProtocol):
|
||||
self.websocket.connection_made(request.transport)
|
||||
self.websocket.connection_open()
|
||||
return self.websocket
|
||||
|
||||
|
||||
class WebSocketConnection:
|
||||
|
||||
# TODO
|
||||
# - Implement ping/pong
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
send: Callable[[ASIMessage], Awaitable[None]],
|
||||
receive: Callable[[], Awaitable[ASIMessage]],
|
||||
) -> None:
|
||||
self._send = send
|
||||
self._receive = receive
|
||||
|
||||
async def send(self, data: Union[str, bytes], *args, **kwargs) -> None:
|
||||
message = {"type": "websocket.send"}
|
||||
|
||||
try:
|
||||
data.decode()
|
||||
except AttributeError:
|
||||
message.update({"text": str(data)})
|
||||
else:
|
||||
message.update({"bytes": data})
|
||||
|
||||
await self._send(message)
|
||||
|
||||
async def recv(self, *args, **kwargs) -> Optional[str]:
|
||||
message = await self._receive()
|
||||
|
||||
if message["type"] == "websocket.receive":
|
||||
return message["text"]
|
||||
elif message["type"] == "websocket.disconnect":
|
||||
pass
|
||||
# await self._send({
|
||||
# "type": "websocket.close"
|
||||
# })
|
||||
|
||||
async def accept(self) -> None:
|
||||
await self._send({"type": "websocket.accept", "subprotocol": ""})
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user