diff --git a/sanic/blueprints.py b/sanic/blueprints.py index c9a4b8ac..0c14f4bc 100644 --- a/sanic/blueprints.py +++ b/sanic/blueprints.py @@ -35,6 +35,9 @@ class Blueprint: # Routes for future in self.routes: + # attach the blueprint name to the handler so that it can be + # prefixed properly in the router + future.handler.__blueprintname__ = self.name # Prepend the blueprint URI prefix if available uri = url_prefix + future.uri if url_prefix else future.uri app.route( diff --git a/sanic/exceptions.py b/sanic/exceptions.py index d986cd08..201edbf6 100644 --- a/sanic/exceptions.py +++ b/sanic/exceptions.py @@ -120,6 +120,10 @@ class ServerError(SanicException): status_code = 500 +class URLBuildError(SanicException): + status_code = 500 + + class FileNotFound(NotFound): status_code = 404 diff --git a/sanic/router.py b/sanic/router.py index 5ed21136..25b1b9ee 100644 --- a/sanic/router.py +++ b/sanic/router.py @@ -4,8 +4,8 @@ from functools import lru_cache from .exceptions import NotFound, InvalidUsage from .views import CompositionView -Route = namedtuple('Route', ['handler', 'methods', 'pattern', 'parameters']) -Parameter = namedtuple('Parameter', ['name', 'cast']) +Route = namedtuple('Route', ['handler', 'methods', 'pattern', 'parameters', 'name']) +Parameter = namedtuple('Parameter', ['name', 'cast', 'pattern']) REGEX_TYPES = { 'string': (str, r'[^/]+'), @@ -59,6 +59,7 @@ class Router: routes_static = None routes_dynamic = None routes_always_check = None + parameter_pattern = re.compile(r'<(.+?)>') def __init__(self): self.routes_all = {} @@ -67,6 +68,19 @@ class Router: self.routes_always_check = [] self.hosts = None + def parse_parameter_string(self, parameter_string): + # We could receive NAME or NAME:PATTERN + name = parameter_string + pattern = 'string' + if ':' in parameter_string: + name, pattern = parameter_string.split(':', 1) + + default = (str, pattern) + # Pull from pre-configured types + _type, pattern = REGEX_TYPES.get(pattern, default) + + return name, _type, pattern + def add(self, uri, methods, handler, host=None): """ Adds a handler to the route list @@ -106,14 +120,13 @@ class Router: def add_parameter(match): # We could receive NAME or NAME:PATTERN name = match.group(1) - pattern = 'string' - if ':' in name: - name, pattern = name.split(':', 1) + name, _type, pattern = self.parse_parameter_string(name) - default = (str, pattern) - # Pull from pre-configured types - _type, pattern = REGEX_TYPES.get(pattern, default) - parameter = Parameter(name=name, cast=_type) + # store a regex for matching on a specific parameter + # this will be useful for URL building + specific_parameter_pattern = '^{}$'.format(pattern) + parameter = Parameter( + name=name, cast=_type, pattern=specific_parameter_pattern) parameters.append(parameter) # Mark the whole route as unhashable if it has the hash key in it @@ -125,7 +138,7 @@ class Router: return '({})'.format(pattern) - pattern_string = re.sub(r'<(.+?)>', add_parameter, uri) + pattern_string = re.sub(self.parameter_pattern, add_parameter, uri) pattern = re.compile(r'^{}$'.format(pattern_string)) def merge_route(route, methods, handler): @@ -169,9 +182,17 @@ class Router: if route: route = merge_route(route, methods, handler) else: + # prefix the handler name with the blueprint name + # if available + if hasattr(handler, '__blueprintname__'): + handler_name = '{}.{}'.format( + handler.__blueprintname__, handler.__name__) + else: + handler_name = handler.__name__ + route = Route( handler=handler, methods=methods, pattern=pattern, - parameters=parameters) + parameters=parameters, name=handler_name) self.routes_all[uri] = route if properties['unhashable']: @@ -208,6 +229,14 @@ class Router: if clean_cache: self._get.cache_clear() + @lru_cache(maxsize=ROUTER_CACHE_SIZE) + def find_route_by_view_name(self, view_name): + for uri, route in self.routes_all.items(): + if route.name == view_name: + return uri, route + + return (None, None) + def get(self, request): """ Gets a request handler based on the URL of the request, or raises an diff --git a/sanic/sanic.py b/sanic/sanic.py index 8cfe08dd..11b104ac 100644 --- a/sanic/sanic.py +++ b/sanic/sanic.py @@ -3,13 +3,15 @@ from asyncio import get_event_loop from collections import deque from functools import partial from inspect import isawaitable, stack, getmodulename +import re from traceback import format_exc +from urllib.parse import urlencode, urlunparse import warnings from .config import Config from .constants import HTTP_METHODS from .exceptions import Handler -from .exceptions import ServerError +from .exceptions import ServerError, URLBuildError from .log import log from .response import HTTPResponse from .router import Router @@ -192,6 +194,63 @@ class Sanic: DeprecationWarning) return self.blueprint(*args, **kwargs) + def url_for(self, view_name: str, **kwargs): + uri, route = self.router.find_route_by_view_name(view_name) + + if not uri or not route: + raise URLBuildError( + 'Endpoint with name `{}` was not found'.format( + view_name)) + + out = uri + matched_params = re.findall( + self.router.parameter_pattern, uri) + + for match in matched_params: + name, _type, pattern = self.router.parse_parameter_string( + match) + specific_pattern = '^{}$'.format(pattern) + + supplied_param = None + if kwargs.get(name): + supplied_param = kwargs.get(name) + del kwargs[name] + else: + raise URLBuildError( + 'Required parameter `{}` was not passed to url_for'.format( + name)) + + supplied_param = str(supplied_param) + passes_pattern = re.match(specific_pattern, supplied_param) + + if not passes_pattern: + if _type != str: + msg = ( + 'Value "{}" for parameter `{}` does not ' + 'match pattern for type `{}`: {}'.format( + supplied_param, name, _type.__name__, pattern)) + else: + msg = ( + 'Value "{}" for parameter `{}` ' + 'does not satisfy pattern {}'.format( + supplied_param, name, pattern)) + raise URLBuildError(msg) + + replacement_regex = '(<{}.*?>)'.format(name) + + out = re.sub( + replacement_regex, supplied_param, out) + + # parse the remainder of the keyword arguments into a querystring + if kwargs: + query_string = urlencode(kwargs) + out = urlunparse(( + '', '', out, + '', query_string, '' + )) + + return out + # -------------------------------------------------------------------- # # Request Handling # -------------------------------------------------------------------- # diff --git a/sanic/views.py b/sanic/views.py index 407ba136..5d8e9d40 100644 --- a/sanic/views.py +++ b/sanic/views.py @@ -64,6 +64,7 @@ class HTTPMethodView: view.view_class = cls view.__doc__ = cls.__doc__ view.__module__ = cls.__module__ + view.__name__ = cls.__name__ return view diff --git a/tests/test_url_building.py b/tests/test_url_building.py new file mode 100644 index 00000000..c8209557 --- /dev/null +++ b/tests/test_url_building.py @@ -0,0 +1,261 @@ +import pytest as pytest +from urllib.parse import urlsplit, parse_qsl + +from sanic import Sanic +from sanic.response import text +from sanic.views import HTTPMethodView, CompositionView +from sanic.blueprints import Blueprint +from sanic.utils import sanic_endpoint_test +from sanic.exceptions import URLBuildError + +import string + + +def _generate_handlers_from_names(app, l): + for name in l: + # this is the easiest way to generate functions with dynamic names + exec('@app.route(name)\ndef {}(request):\n\treturn text("{}")'.format( + name, name)) + + +@pytest.fixture +def simple_app(): + app = Sanic('simple_app') + handler_names = list(string.ascii_letters) + + _generate_handlers_from_names(app, handler_names) + + return app + + +def test_simple_url_for_getting(simple_app): + for letter in string.ascii_letters: + url = simple_app.url_for(letter) + + assert url == '/{}'.format(letter) + request, response = sanic_endpoint_test( + simple_app, uri=url) + assert response.status == 200 + assert response.text == letter + + +def test_fails_if_endpoint_not_found(): + app = Sanic('fail_url_build') + + @app.route('/fail') + def fail(): + return text('this should fail') + + with pytest.raises(URLBuildError) as e: + app.url_for('passes') + + assert str(e.value) == 'Endpoint with name `passes` was not found' + + +def test_fails_url_build_if_param_not_passed(): + url = '/' + + for letter in string.ascii_letters: + url += '<{}>/'.format(letter) + + app = Sanic('fail_url_build') + + @app.route(url) + def fail(): + return text('this should fail') + + fail_args = list(string.ascii_letters) + fail_args.pop() + + fail_kwargs = {l: l for l in fail_args} + + with pytest.raises(URLBuildError) as e: + app.url_for('fail', **fail_kwargs) + + assert 'Required parameter `Z` was not passed to url_for' in str(e.value) + + +COMPLEX_PARAM_URL = ( + '///' + '//') +PASSING_KWARGS = { + 'foo': 4, 'four_letter_string': 'woof', + 'two_letter_string': 'ba', 'normal_string': 'normal', + 'some_number': '1.001'} +EXPECTED_BUILT_URL = '/4/woof/ba/normal/1.001' + + +def test_fails_with_int_message(): + app = Sanic('fail_url_build') + + @app.route(COMPLEX_PARAM_URL) + def fail(): + return text('this should fail') + + failing_kwargs = dict(PASSING_KWARGS) + failing_kwargs['foo'] = 'not_int' + + with pytest.raises(URLBuildError) as e: + app.url_for('fail', **failing_kwargs) + + expected_error = ( + 'Value "not_int" for parameter `foo` ' + 'does not match pattern for type `int`: \d+') + assert str(e.value) == expected_error + + +def test_fails_with_two_letter_string_message(): + app = Sanic('fail_url_build') + + @app.route(COMPLEX_PARAM_URL) + def fail(): + return text('this should fail') + + failing_kwargs = dict(PASSING_KWARGS) + failing_kwargs['two_letter_string'] = 'foobar' + + with pytest.raises(URLBuildError) as e: + app.url_for('fail', **failing_kwargs) + + expected_error = ( + 'Value "foobar" for parameter `two_letter_string` ' + 'does not satisfy pattern [A-z]{2}') + + assert str(e.value) == expected_error + + +def test_fails_with_number_message(): + app = Sanic('fail_url_build') + + @app.route(COMPLEX_PARAM_URL) + def fail(): + return text('this should fail') + + failing_kwargs = dict(PASSING_KWARGS) + failing_kwargs['some_number'] = 'foo' + + with pytest.raises(URLBuildError) as e: + app.url_for('fail', **failing_kwargs) + + expected_error = ( + 'Value "foo" for parameter `some_number` ' + 'does not match pattern for type `float`: [0-9\\\\.]+') + + assert str(e.value) == expected_error + + +def test_adds_other_supplied_values_as_query_string(): + app = Sanic('passes') + + @app.route(COMPLEX_PARAM_URL) + def passes(): + return text('this should pass') + + new_kwargs = dict(PASSING_KWARGS) + new_kwargs['added_value_one'] = 'one' + new_kwargs['added_value_two'] = 'two' + + url = app.url_for('passes', **new_kwargs) + + query = dict(parse_qsl(urlsplit(url).query)) + + assert query['added_value_one'] == 'one' + assert query['added_value_two'] == 'two' + + +@pytest.fixture +def blueprint_app(): + app = Sanic('blueprints') + + first_print = Blueprint('first', url_prefix='/first') + second_print = Blueprint('second', url_prefix='/second') + + @first_print.route('/foo') + def foo(): + return text('foo from first') + + @first_print.route('/foo/') + def foo_with_param(request, param): + return text( + 'foo from first : {}'.format(param)) + + @second_print.route('/foo') # noqa + def foo(): + return text('foo from second') + + @second_print.route('/foo/') # noqa + def foo_with_param(request, param): + return text( + 'foo from second : {}'.format(param)) + + app.blueprint(first_print) + app.blueprint(second_print) + + return app + + +def test_blueprints_are_named_correctly(blueprint_app): + first_url = blueprint_app.url_for('first.foo') + assert first_url == '/first/foo' + + second_url = blueprint_app.url_for('second.foo') + assert second_url == '/second/foo' + + +def test_blueprints_work_with_params(blueprint_app): + first_url = blueprint_app.url_for('first.foo_with_param', param='bar') + assert first_url == '/first/foo/bar' + + second_url = blueprint_app.url_for('second.foo_with_param', param='bar') + assert second_url == '/second/foo/bar' + + +@pytest.fixture +def methodview_app(): + app = Sanic('methodview') + + class ViewOne(HTTPMethodView): + def get(self, request): + return text('I am get method') + + def post(self, request): + return text('I am post method') + + def put(self, request): + return text('I am put method') + + def patch(self, request): + return text('I am patch method') + + def delete(self, request): + return text('I am delete method') + + app.add_route(ViewOne.as_view('view_one'), '/view_one') + + class ViewTwo(HTTPMethodView): + def get(self, request): + return text('I am get method') + + def post(self, request): + return text('I am post method') + + def put(self, request): + return text('I am put method') + + def patch(self, request): + return text('I am patch method') + + def delete(self, request): + return text('I am delete method') + + app.add_route(ViewTwo.as_view(), '/view_two') + + return app + + +def test_methodview_naming(methodview_app): + viewone_url = methodview_app.url_for('ViewOne') + viewtwo_url = methodview_app.url_for('ViewTwo') + + assert viewone_url == '/view_one' + assert viewtwo_url == '/view_two'