Merge branch 'use-route-name-for-method'

This commit is contained in:
lixxu 2017-08-22 14:02:56 +08:00
commit 145cdd5c1b
3 changed files with 69 additions and 33 deletions

View File

@ -266,4 +266,38 @@ app.blueprint(bp)
# then you need use `app.url_for('test_named_bp.get_handler')` # then you need use `app.url_for('test_named_bp.get_handler')`
# instead of `app.url_for('test_named_bp.handler')` # instead of `app.url_for('test_named_bp.handler')`
# different names can be used for same url with different methods
@app.get('/test', name='route_test')
def handler(request):
return text('OK')
@app.post('/test', name='route_post')
def handler2(request):
return text('OK POST')
@app.put('/test', name='route_put')
def handler3(request):
return text('OK PUT')
# below url are the same, you can use any of them
# '/test'
app.url_for('route_test')
# app.url_for('route_post')
# app.url_for('route_put')
# for same handler name with different methods
# you need specify the name (it's url_for issue)
@app.get('/get')
def handler(request):
return text('OK')
@app.post('/post', name='post_handler')
def handler(request):
return text('OK')
# then
# app.url_for('handler') == '/get'
# app.url_for('post_handler') == '/post'
``` ```

View File

@ -67,6 +67,7 @@ class Router:
def __init__(self): def __init__(self):
self.routes_all = {} self.routes_all = {}
self.routes_names = {}
self.routes_static = {} self.routes_static = {}
self.routes_dynamic = defaultdict(list) self.routes_dynamic = defaultdict(list)
self.routes_always_check = [] self.routes_always_check = []
@ -125,13 +126,12 @@ class Router:
# Add versions with and without trailing / # Add versions with and without trailing /
slash_is_missing = ( slash_is_missing = (
not uri[-1] == '/' not uri[-1] == '/' and not self.routes_all.get(uri + '/', False)
and not self.routes_all.get(uri + '/', False)
) )
without_slash_is_missing = ( without_slash_is_missing = (
uri[-1] == '/' uri[-1] == '/' and not
and not self.routes_all.get(uri[:-1], False) self.routes_all.get(uri[:-1], False) and not
and not uri == '/' uri == '/'
) )
# add version with trailing slash # add version with trailing slash
if slash_is_missing: if slash_is_missing:
@ -229,22 +229,26 @@ class Router:
else: else:
route = self.routes_all.get(uri) route = self.routes_all.get(uri)
# prefix the handler name with the blueprint name
# if available
if hasattr(handler, '__blueprintname__'):
handler_name = '{}.{}'.format(
handler.__blueprintname__, name or handler.__name__)
else:
handler_name = name or getattr(handler, '__name__', None)
if route: if route:
route = merge_route(route, methods, handler) route = merge_route(route, methods, handler)
else: else:
# prefix the handler name with the blueprint name
# if available
if hasattr(handler, '__blueprintname__'):
handler_name = '{}.{}'.format(
handler.__blueprintname__, name or handler.__name__)
else:
handler_name = name or getattr(handler, '__name__', None)
route = Route( route = Route(
handler=handler, methods=methods, pattern=pattern, handler=handler, methods=methods, pattern=pattern,
parameters=parameters, name=handler_name, uri=uri) parameters=parameters, name=handler_name, uri=uri)
self.routes_all[uri] = route self.routes_all[uri] = route
pairs = self.routes_names.get(handler_name)
if not (pairs and (pairs[0] + '/' == uri or uri + '/' == pairs[0])):
self.routes_names[handler_name] = (uri, route)
if properties['unhashable']: if properties['unhashable']:
self.routes_always_check.append(route) self.routes_always_check.append(route)
elif parameters: elif parameters:
@ -265,6 +269,11 @@ class Router:
uri = host + uri uri = host + uri
try: try:
route = self.routes_all.pop(uri) route = self.routes_all.pop(uri)
for handler_name, pairs in self.routes_names.items():
if pairs[0] == uri:
self.routes_names.pop(handler_name)
break
except KeyError: except KeyError:
raise RouteDoesNotExist("Route was not registered: {}".format(uri)) raise RouteDoesNotExist("Route was not registered: {}".format(uri))
@ -289,11 +298,7 @@ class Router:
if not view_name: if not view_name:
return (None, None) return (None, None)
for uri, route in self.routes_all.items(): return self.routes_names.get(view_name, (None, None))
if route.name == view_name:
return uri, route
return (None, None)
def get(self, request): def get(self, request):
"""Get a request handler based on the URL of the request, or raises an """Get a request handler based on the URL of the request, or raises an

View File

@ -7,7 +7,6 @@ import pytest
from sanic import Sanic from sanic import Sanic
from sanic.blueprints import Blueprint from sanic.blueprints import Blueprint
from sanic.response import text from sanic.response import text
from sanic.router import RouteExists, RouteDoesNotExist
from sanic.exceptions import URLBuildError from sanic.exceptions import URLBuildError
from sanic.constants import HTTP_METHODS from sanic.constants import HTTP_METHODS
@ -360,11 +359,7 @@ def test_overload_routes():
return text('OK1') return text('OK1')
@app.route('/overload', methods=['POST', 'PUT'], name='route_second') @app.route('/overload', methods=['POST', 'PUT'], name='route_second')
async def handler2(request): async def handler1(request):
return text('OK2')
@app.route('/overload2', methods=['POST', 'PUT'], name='route_third')
async def handler3(request):
return text('OK2') return text('OK2')
request, response = app.test_client.get(app.url_for('route_first')) request, response = app.test_client.get(app.url_for('route_first'))
@ -376,16 +371,18 @@ def test_overload_routes():
request, response = app.test_client.put(app.url_for('route_first')) request, response = app.test_client.put(app.url_for('route_first'))
assert response.text == 'OK2' assert response.text == 'OK2'
request, response = app.test_client.get(app.url_for('route_second'))
assert response.text == 'OK1'
request, response = app.test_client.post(app.url_for('route_second'))
assert response.text == 'OK2'
request, response = app.test_client.put(app.url_for('route_second'))
assert response.text == 'OK2'
assert app.router.routes_all['/overload'].name == 'route_first' assert app.router.routes_all['/overload'].name == 'route_first'
with pytest.raises(URLBuildError): with pytest.raises(URLBuildError):
app.url_for('handler1') app.url_for('handler1')
with pytest.raises(URLBuildError): assert app.url_for('route_first') == '/overload'
app.url_for('handler2') assert app.url_for('route_second') == app.url_for('route_first')
with pytest.raises(URLBuildError):
app.url_for('route_second')
assert app.url_for('route_third') == '/overload2'
with pytest.raises(URLBuildError):
app.url_for('handler3')