merge test conflicts

This commit is contained in:
Adam Hopkins 2021-04-20 00:53:42 +03:00
parent 93a0246c03
commit 492d6fd19d
16 changed files with 240 additions and 56 deletions

View File

@ -164,10 +164,12 @@ class ASGIApp:
Read and stream the body in chunks from an incoming ASGI message. Read and stream the body in chunks from an incoming ASGI message.
""" """
message = await self.transport.receive() message = await self.transport.receive()
body = message.get("body", b"")
if not message.get("more_body", False): if not message.get("more_body", False):
self.request_body = False self.request_body = False
if not body:
return None return None
return message.get("body", b"") return body
async def __aiter__(self): async def __aiter__(self):
while self.request_body: while self.request_body:

View File

@ -1,9 +1,13 @@
from collections.abc import MutableSequence from collections.abc import MutableSequence
from typing import List, Optional, Union from typing import TYPE_CHECKING, List, Optional, Union
import sanic import sanic
if TYPE_CHECKING:
from sanic.blueprints import Blueprint
class BlueprintGroup(MutableSequence): class BlueprintGroup(MutableSequence):
""" """
This class provides a mechanism to implement a Blueprint Group This class provides a mechanism to implement a Blueprint Group
@ -56,7 +60,12 @@ class BlueprintGroup(MutableSequence):
__slots__ = ("_blueprints", "_url_prefix", "_version", "_strict_slashes") __slots__ = ("_blueprints", "_url_prefix", "_version", "_strict_slashes")
def __init__(self, url_prefix=None, version=None, strict_slashes=None): def __init__(
self,
url_prefix: Optional[str] = None,
version: Optional[Union[int, str, float]] = None,
strict_slashes: Optional[bool] = None,
):
""" """
Create a new Blueprint Group Create a new Blueprint Group
@ -65,13 +74,13 @@ class BlueprintGroup(MutableSequence):
inherited by each of the Blueprint inherited by each of the Blueprint
:param strict_slashes: URL Strict slash behavior indicator :param strict_slashes: URL Strict slash behavior indicator
""" """
self._blueprints = [] self._blueprints: List[Blueprint] = []
self._url_prefix = url_prefix self._url_prefix = url_prefix
self._version = version self._version = version
self._strict_slashes = strict_slashes self._strict_slashes = strict_slashes
@property @property
def url_prefix(self) -> str: def url_prefix(self) -> Optional[Union[int, str, float]]:
""" """
Retrieve the URL prefix being used for the Current Blueprint Group Retrieve the URL prefix being used for the Current Blueprint Group

View File

@ -70,7 +70,7 @@ class Blueprint(BaseSanic):
name: str, name: str,
url_prefix: Optional[str] = None, url_prefix: Optional[str] = None,
host: Optional[str] = None, host: Optional[str] = None,
version: Optional[int] = None, version: Optional[Union[int, str, float]] = None,
strict_slashes: Optional[bool] = None, strict_slashes: Optional[bool] = None,
): ):
super().__init__() super().__init__()

View File

@ -82,6 +82,7 @@ class Http:
"request_max_size", "request_max_size",
"response", "response",
"response_func", "response_func",
"response_size",
"response_bytes_left", "response_bytes_left",
"upgrade_websocket", "upgrade_websocket",
] ]
@ -270,6 +271,7 @@ class Http:
size = len(data) size = len(data)
headers = res.headers headers = res.headers
status = res.status status = res.status
self.response_size = size
if not isinstance(status, int) or status < 200: if not isinstance(status, int) or status < 200:
raise RuntimeError(f"Invalid response status {status!r}") raise RuntimeError(f"Invalid response status {status!r}")
@ -424,7 +426,9 @@ class Http:
req, res = self.request, self.response req, res = self.request, self.response
extra = { extra = {
"status": getattr(res, "status", 0), "status": getattr(res, "status", 0),
"byte": getattr(self, "response_bytes_left", -1), "byte": getattr(
self, "response_bytes_left", getattr(self, "response_size", -1)
),
"host": "UNKNOWN", "host": "UNKNOWN",
"request": "nil", "request": "nil",
} }

View File

@ -45,7 +45,7 @@ class RouteMixin:
host: Optional[str] = None, host: Optional[str] = None,
strict_slashes: Optional[bool] = None, strict_slashes: Optional[bool] = None,
stream: bool = False, stream: bool = False,
version: Optional[int] = None, version: Optional[Union[int, str, float]] = None,
name: Optional[str] = None, name: Optional[str] = None,
ignore_body: bool = False, ignore_body: bool = False,
apply: bool = True, apply: bool = True,

View File

@ -33,7 +33,7 @@ class Router(BaseRouter):
return self.resolve( return self.resolve(
path=path, path=path,
method=method, method=method,
extra={"host": host}, extra={"host": host} if host else None,
) )
except RoutingNotFound as e: except RoutingNotFound as e:
raise NotFound("Requested URL {} not found".format(e.path)) raise NotFound("Requested URL {} not found".format(e.path))
@ -161,7 +161,7 @@ class Router(BaseRouter):
@property @property
def routes_all(self): def routes_all(self):
return self.routes return {route.parts: route for route in self.routes}
@property @property
def routes_static(self): def routes_static(self):

View File

@ -5,7 +5,7 @@ import asyncio
from inspect import isawaitable from inspect import isawaitable
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
from sanic_routing import BaseRouter, Route # type: ignore from sanic_routing import BaseRouter, Route, RouteGroup # type: ignore
from sanic_routing.exceptions import NotFound # type: ignore from sanic_routing.exceptions import NotFound # type: ignore
from sanic_routing.utils import path_to_parts # type: ignore from sanic_routing.utils import path_to_parts # type: ignore
@ -20,17 +20,11 @@ RESERVED_NAMESPACES = (
class Signal(Route): class Signal(Route):
def get_handler(self, raw_path, method, _): ...
method = method or self.router.DEFAULT_METHOD
raw_path = raw_path.lstrip(self.router.delimiter)
try: class SignalGroup(RouteGroup):
return self.handlers[raw_path][method] ...
except (IndexError, KeyError):
raise self.router.method_handler_exception(
f"Method '{method}' not found on {self}",
method=method,
allowed_methods=set(self.methods[raw_path]),
)
class SignalRouter(BaseRouter): class SignalRouter(BaseRouter):
@ -38,6 +32,7 @@ class SignalRouter(BaseRouter):
super().__init__( super().__init__(
delimiter=".", delimiter=".",
route_class=Signal, route_class=Signal,
group_class=SignalGroup,
stacking=True, stacking=True,
) )
self.ctx.loop = None self.ctx.loop = None
@ -49,7 +44,13 @@ class SignalRouter(BaseRouter):
): ):
extra = condition or {} extra = condition or {}
try: try:
return self.resolve(f".{event}", extra=extra) group, param_basket = self.find_route(
f".{event}",
self.DEFAULT_METHOD,
self,
{"__params__": {}},
extra=extra,
)
except NotFound: except NotFound:
message = "Could not find signal %s" message = "Could not find signal %s"
terms: List[Union[str, Optional[Dict[str, str]]]] = [event] terms: List[Union[str, Optional[Dict[str, str]]]] = [event]
@ -58,15 +59,19 @@ class SignalRouter(BaseRouter):
terms.append(extra) terms.append(extra)
raise NotFound(message % tuple(terms)) raise NotFound(message % tuple(terms))
params = param_basket.pop("__params__")
return group, [route.handler for route in group], params
async def _dispatch( async def _dispatch(
self, self,
event: str, event: str,
context: Optional[Dict[str, Any]] = None, context: Optional[Dict[str, Any]] = None,
condition: Optional[Dict[str, str]] = None, condition: Optional[Dict[str, str]] = None,
) -> None: ) -> None:
signal, handlers, params = self.get(event, condition=condition) group, handlers, params = self.get(event, condition=condition)
signal_event = signal.ctx.event events = [signal.ctx.event for signal in group]
for signal_event in events:
signal_event.set() signal_event.set()
if context: if context:
params.update(context) params.update(context)
@ -78,6 +83,7 @@ class SignalRouter(BaseRouter):
if isawaitable(maybe_coroutine): if isawaitable(maybe_coroutine):
await maybe_coroutine await maybe_coroutine
finally: finally:
for signal_event in events:
signal_event.clear() signal_event.clear()
async def dispatch( async def dispatch(
@ -116,7 +122,7 @@ class SignalRouter(BaseRouter):
handler, handler,
requirements=condition, requirements=condition,
name=name, name=name,
overwrite=True, append=True,
) # type: ignore ) # type: ignore
def finalize(self, do_compile: bool = True): def finalize(self, do_compile: bool = True):
@ -125,7 +131,7 @@ class SignalRouter(BaseRouter):
except RuntimeError: except RuntimeError:
raise RuntimeError("Cannot finalize signals outside of event loop") raise RuntimeError("Cannot finalize signals outside of event loop")
for signal in self.routes.values(): for signal in self.routes:
signal.ctx.event = asyncio.Event() signal.ctx.event = asyncio.Event()
return super().finalize(do_compile=do_compile) return super().finalize(do_compile=do_compile)

View File

@ -83,7 +83,7 @@ ujson = "ujson>=1.35" + env_dependency
uvloop = "uvloop>=0.5.3" + env_dependency uvloop = "uvloop>=0.5.3" + env_dependency
requirements = [ requirements = [
"sanic-routing", "sanic-routing>=0.6.0",
"httptools>=0.0.10", "httptools>=0.0.10",
uvloop, uvloop,
ujson, ujson,

View File

@ -80,6 +80,12 @@ def test_dont_load_env():
del environ["SANIC_TEST_ANSWER"] del environ["SANIC_TEST_ANSWER"]
@pytest.mark.parametrize("load_env", [None, False, "", "MYAPP_"])
def test_load_env_deprecation(load_env):
with pytest.warns(DeprecationWarning, match=r"21\.12"):
_ = Sanic(name=__name__, load_env=load_env)
def test_load_env_prefix(): def test_load_env_prefix():
environ["MYAPP_TEST_ANSWER"] = "42" environ["MYAPP_TEST_ANSWER"] = "42"
app = Sanic(name=__name__, load_env="MYAPP_") app = Sanic(name=__name__, load_env="MYAPP_")
@ -87,6 +93,14 @@ def test_load_env_prefix():
del environ["MYAPP_TEST_ANSWER"] del environ["MYAPP_TEST_ANSWER"]
@pytest.mark.parametrize("env_prefix", [None, ""])
def test_empty_load_env_prefix(env_prefix):
environ["SANIC_TEST_ANSWER"] = "42"
app = Sanic(name=__name__, env_prefix=env_prefix)
assert getattr(app.config, "TEST_ANSWER", None) is None
del environ["SANIC_TEST_ANSWER"]
def test_load_env_prefix_float_values(): def test_load_env_prefix_float_values():
environ["MYAPP_TEST_ROI"] = "2.3" environ["MYAPP_TEST_ROI"] = "2.3"
app = Sanic(name=__name__, load_env="MYAPP_") app = Sanic(name=__name__, load_env="MYAPP_")

View File

@ -209,13 +209,13 @@ def test_named_static_routes():
return text("OK2") return text("OK2")
assert app.router.routes_all[("test",)].name == "app.route_test" assert app.router.routes_all[("test",)].name == "app.route_test"
assert app.router.routes_static[("test",)].name == "app.route_test" assert app.router.routes_static[("test",)][0].name == "app.route_test"
assert app.url_for("route_test") == "/test" assert app.url_for("route_test") == "/test"
with pytest.raises(URLBuildError): with pytest.raises(URLBuildError):
app.url_for("handler1") app.url_for("handler1")
assert app.router.routes_all[("pizazz",)].name == "app.route_pizazz" assert app.router.routes_all[("pizazz",)].name == "app.route_pizazz"
assert app.router.routes_static[("pizazz",)].name == "app.route_pizazz" assert app.router.routes_static[("pizazz",)][0].name == "app.route_pizazz"
assert app.url_for("route_pizazz") == "/pizazz" assert app.url_for("route_pizazz") == "/pizazz"
with pytest.raises(URLBuildError): with pytest.raises(URLBuildError):
app.url_for("handler2") app.url_for("handler2")
@ -347,13 +347,13 @@ def test_static_add_named_route():
app.add_route(handler2, "/test2", name="route_test2") app.add_route(handler2, "/test2", name="route_test2")
assert app.router.routes_all[("test",)].name == "app.route_test" assert app.router.routes_all[("test",)].name == "app.route_test"
assert app.router.routes_static[("test",)].name == "app.route_test" assert app.router.routes_static[("test",)][0].name == "app.route_test"
assert app.url_for("route_test") == "/test" assert app.url_for("route_test") == "/test"
with pytest.raises(URLBuildError): with pytest.raises(URLBuildError):
app.url_for("handler1") app.url_for("handler1")
assert app.router.routes_all[("test2",)].name == "app.route_test2" assert app.router.routes_all[("test2",)].name == "app.route_test2"
assert app.router.routes_static[("test2",)].name == "app.route_test2" assert app.router.routes_static[("test2",)][0].name == "app.route_test2"
assert app.url_for("route_test2") == "/test2" assert app.url_for("route_test2") == "/test2"
with pytest.raises(URLBuildError): with pytest.raises(URLBuildError):
app.url_for("handler2") app.url_for("handler2")

View File

@ -104,7 +104,7 @@ def test_route_assigned_to_request(app):
return response.empty() return response.empty()
request, _ = app.test_client.get("/") request, _ = app.test_client.get("/")
assert request.route is list(app.router.routes.values())[0] assert request.route is list(app.router.routes)[0]
def test_protocol_attribute(app): def test_protocol_attribute(app):

View File

@ -253,6 +253,31 @@ async def test_empty_json_asgi(app):
assert response.body == b"null" assert response.body == b"null"
def test_echo_json(app):
@app.post("/")
async def handler(request):
return json(request.json)
data = {"foo": "bar"}
request, response = app.test_client.post("/", json=data)
assert response.status == 200
assert response.json == data
@pytest.mark.asyncio
async def test_echo_json_asgi(app):
@app.post("/")
async def handler(request):
return json(request.json)
data = {"foo": "bar"}
request, response = await app.asgi_client.post("/", json=data)
assert response.status == 200
assert response.json == data
def test_invalid_json(app): def test_invalid_json(app):
@app.post("/") @app.post("/")
async def handler(request): async def handler(request):
@ -292,6 +317,17 @@ def test_query_string(app):
assert request.args.get("test3", default="My value") == "My value" assert request.args.get("test3", default="My value") == "My value"
def test_popped_stays_popped(app):
@app.route("/")
async def handler(request):
return text("OK")
request, response = app.test_client.get("/", params=[("test1", "1")])
assert request.args.pop("test1") == ["1"]
assert "test1" not in request.args
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_string_asgi(app): async def test_query_string_asgi(app):
@app.route("/") @app.route("/")
@ -2159,3 +2195,72 @@ def test_safe_method_with_body(app):
assert request.body == data.encode("utf-8") assert request.body == data.encode("utf-8")
assert request.json.get("test") == "OK" assert request.json.get("test") == "OK"
assert response.body == b"OK" assert response.body == b"OK"
def test_conflicting_body_methods_overload(app):
@app.put("/")
@app.put("/p/")
@app.put("/p/<foo>")
async def put(request, foo=None):
return json(
{"name": request.route.name, "body": str(request.body), "foo": foo}
)
@app.delete("/p/<foo>")
async def delete(request, foo):
return json(
{"name": request.route.name, "body": str(request.body), "foo": foo}
)
payload = {"test": "OK"}
data = str(json_dumps(payload).encode())
_, response = app.test_client.put("/", json=payload)
assert response.status == 200
assert response.json == {
"name": "test_conflicting_body_methods_overload.put",
"foo": None,
"body": data,
}
_, response = app.test_client.put("/p", json=payload)
assert response.status == 200
assert response.json == {
"name": "test_conflicting_body_methods_overload.put",
"foo": None,
"body": data,
}
_, response = app.test_client.put("/p/test", json=payload)
assert response.status == 200
assert response.json == {
"name": "test_conflicting_body_methods_overload.put",
"foo": "test",
"body": data,
}
_, response = app.test_client.delete("/p/test")
assert response.status == 200
assert response.json == {
"name": "test_conflicting_body_methods_overload.delete",
"foo": "test",
"body": str("".encode()),
}
def test_handler_overload(app):
@app.get(
"/long/sub/route/param_a/<param_a:string>/param_b/<param_b:string>"
)
@app.post("/long/sub/route/")
def handler(request, **kwargs):
return json(kwargs)
_, response = app.test_client.get(
"/long/sub/route/param_a/foo/param_b/bar"
)
assert response.status == 200
assert response.json == {
"param_a": "foo",
"param_b": "bar",
}
_, response = app.test_client.post("/long/sub/route")
assert response.status == 200
assert response.json == {}

View File

@ -65,7 +65,9 @@ def test_method_not_allowed():
} }
request, response = app.test_client.post("/") request, response = app.test_client.post("/")
assert set(response.headers["Allow"].split(", ")) == {"GET", "HEAD"} assert set(response.headers["Allow"].split(", ")) == {
"GET",
}
app.router.reset() app.router.reset()
@ -78,7 +80,6 @@ def test_method_not_allowed():
assert set(response.headers["Allow"].split(", ")) == { assert set(response.headers["Allow"].split(", ")) == {
"GET", "GET",
"POST", "POST",
"HEAD",
} }
assert response.headers["Content-Length"] == "0" assert response.headers["Content-Length"] == "0"
@ -87,7 +88,6 @@ def test_method_not_allowed():
assert set(response.headers["Allow"].split(", ")) == { assert set(response.headers["Allow"].split(", ")) == {
"GET", "GET",
"POST", "POST",
"HEAD",
} }
assert response.headers["Content-Length"] == "0" assert response.headers["Content-Length"] == "0"

View File

@ -543,9 +543,6 @@ def test_dynamic_route_regex(app):
async def handler(request, folder_id): async def handler(request, folder_id):
return text("OK") return text("OK")
app.router.finalize()
print(app.router.find_route_src)
request, response = app.test_client.get("/folder/test") request, response = app.test_client.get("/folder/test")
assert response.status == 200 assert response.status == 200
@ -587,6 +584,9 @@ def test_dynamic_route_path(app):
async def handler(request, path): async def handler(request, path):
return text("OK") return text("OK")
app.router.finalize()
print(app.router.find_route_src)
request, response = app.test_client.get("/path/1/info") request, response = app.test_client.get("/path/1/info")
assert response.status == 200 assert response.status == 200
@ -1008,14 +1008,8 @@ def test_unmergeable_overload_routes(app):
async def handler2(request): async def handler2(request):
return text("OK1") return text("OK1")
assert ( assert len(app.router.static_routes) == 1
len( assert len(app.router.static_routes[("overload_whole",)].methods) == 3
dict(list(app.router.static_routes.values())[0].handlers)[
"overload_whole"
]
)
== 3
)
request, response = app.test_client.get("/overload_whole") request, response = app.test_client.get("/overload_whole")
assert response.text == "OK1" assert response.text == "OK1"

View File

@ -28,7 +28,8 @@ def test_add_signal_decorator(app):
async def async_signal(*_): async def async_signal(*_):
... ...
assert len(app.signal_router.routes) == 1 assert len(app.signal_router.routes) == 2
assert len(app.signal_router.dynamic_routes) == 1
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -79,13 +80,13 @@ async def test_dispatch_signal_triggers_triggers_event(app):
def sync_signal(*args): def sync_signal(*args):
nonlocal app nonlocal app
nonlocal counter nonlocal counter
signal, *_ = app.signal_router.get("foo.bar.baz") group, *_ = app.signal_router.get("foo.bar.baz")
for signal in group:
counter += signal.ctx.event.is_set() counter += signal.ctx.event.is_set()
app.signal_router.finalize() app.signal_router.finalize()
await app.dispatch("foo.bar.baz") await app.dispatch("foo.bar.baz")
signal, *_ = app.signal_router.get("foo.bar.baz")
assert counter == 1 assert counter == 1
@ -224,7 +225,7 @@ async def test_dispatch_signal_triggers_event_on_bp(app):
app.blueprint(bp) app.blueprint(bp)
app.signal_router.finalize() app.signal_router.finalize()
signal, *_ = app.signal_router.get( signal_group, *_ = app.signal_router.get(
"foo.bar.baz", condition={"blueprint": "bp"} "foo.bar.baz", condition={"blueprint": "bp"}
) )
@ -233,6 +234,7 @@ async def test_dispatch_signal_triggers_event_on_bp(app):
assert isawaitable(waiter) assert isawaitable(waiter)
fut = asyncio.ensure_future(do_wait()) fut = asyncio.ensure_future(do_wait())
for signal in signal_group:
signal.ctx.event.set() signal.ctx.event.set()
await fut await fut

View File

@ -454,3 +454,51 @@ def test_nested_dir(app, static_file_directory):
assert response.status == 200 assert response.status == 200
assert response.text == "foo\n" assert response.text == "foo\n"
def test_stack_trace_on_not_found(app, static_file_directory, caplog):
app.static("/static", static_file_directory)
with caplog.at_level(logging.INFO):
_, response = app.test_client.get("/static/non_existing_file.file")
counter = Counter([r[1] for r in caplog.record_tuples])
assert response.status == 404
assert counter[logging.INFO] == 5
assert counter[logging.ERROR] == 1
def test_no_stack_trace_on_not_found(app, static_file_directory, caplog):
app.static("/static", static_file_directory)
@app.exception(FileNotFound)
async def file_not_found(request, exception):
return text(f"No file: {request.path}", status=404)
with caplog.at_level(logging.INFO):
_, response = app.test_client.get("/static/non_existing_file.file")
counter = Counter([r[1] for r in caplog.record_tuples])
assert response.status == 404
assert counter[logging.INFO] == 5
assert logging.ERROR not in counter
assert response.text == "No file: /static/non_existing_file.file"
def test_multiple_statics(app, static_file_directory):
app.static("/file", get_file_path(static_file_directory, "test.file"))
app.static("/png", get_file_path(static_file_directory, "python.png"))
_, response = app.test_client.get("/file")
assert response.status == 200
assert response.body == get_file_content(
static_file_directory, "test.file"
)
_, response = app.test_client.get("/png")
assert response.status == 200
assert response.body == get_file_content(
static_file_directory, "python.png"
)