Added support for routes with / in custom regexes and updated lru to use url and method

This commit is contained in:
Channel Cat 2016-10-20 11:33:28 +00:00
parent f510550888
commit d4e2d94816
2 changed files with 118 additions and 23 deletions

View File

@ -16,7 +16,11 @@ REGEX_TYPES = {
def url_hash(url): def url_hash(url):
return '/'.join(':' for s in url.split('/')) return url.count('/')
class RouteExists(Exception):
pass
class Router: class Router:
@ -31,16 +35,16 @@ class Router:
function provided Parameters can also have a type by appending :type to function provided Parameters can also have a type by appending :type to
the <parameter>. If no type is provided, a string is expected. A regular the <parameter>. If no type is provided, a string is expected. A regular
expression can also be passed in as the type expression can also be passed in as the type
TODO:
This probably needs optimization for larger sets of routes,
since it checks every route until it finds a match which is bad and
I should feel bad
""" """
routes = None routes_static = None
routes_dynamic = None
routes_always_check = None
def __init__(self): def __init__(self):
self.routes = defaultdict(list) self.routes_all = {}
self.routes_static = {}
self.routes_dynamic = defaultdict(list)
self.routes_always_check = []
def add(self, uri, methods, handler): def add(self, uri, methods, handler):
""" """
@ -52,12 +56,15 @@ class Router:
When executed, it should provide a response object. When executed, it should provide a response object.
:return: Nothing :return: Nothing
""" """
if uri in self.routes_all:
raise RouteExists("Route already registered: {}".format(uri))
# Dict for faster lookups of if method allowed # Dict for faster lookups of if method allowed
if methods: if methods:
methods = frozenset(methods) methods = frozenset(methods)
parameters = [] parameters = []
properties = {"unhashable": None}
def add_parameter(match): def add_parameter(match):
# We could receive NAME or NAME:PATTERN # We could receive NAME or NAME:PATTERN
@ -69,7 +76,13 @@ class Router:
default = (str, pattern) default = (str, pattern)
# Pull from pre-configured types # Pull from pre-configured types
_type, pattern = REGEX_TYPES.get(pattern, default) _type, pattern = REGEX_TYPES.get(pattern, default)
parameters.append(Parameter(name=name, cast=_type)) parameter = Parameter(name=name, cast=_type)
parameters.append(parameter)
# Mark the whole route as unhashable if it has the hash key in it
if re.search('(^|[^^]){1}/', pattern):
properties['unhashable'] = True
return '({})'.format(pattern) return '({})'.format(pattern)
pattern_string = re.sub(r'<(.+?)>', add_parameter, uri) pattern_string = re.sub(r'<(.+?)>', add_parameter, uri)
@ -79,11 +92,14 @@ class Router:
handler=handler, methods=methods, pattern=pattern, handler=handler, methods=methods, pattern=pattern,
parameters=parameters) parameters=parameters)
if parameters: self.routes_all[uri] = route
uri = url_hash(uri) if properties['unhashable']:
self.routes[uri].append(route) self.routes_always_check.append(route)
elif parameters:
self.routes_dynamic[url_hash(uri)].append(route)
else:
self.routes_static[uri] = route
@lru_cache(maxsize=Config.ROUTER_CACHE_SIZE)
def get(self, request): def get(self, request):
""" """
Gets a request handler based on the URL of the request, or raises an Gets a request handler based on the URL of the request, or raises an
@ -91,23 +107,40 @@ class Router:
:param request: Request object :param request: Request object
:return: handler, arguments, keyword arguments :return: handler, arguments, keyword arguments
""" """
route = None return self._get(request.url, request.method)
url = request.url
if url in self.routes: @lru_cache(maxsize=Config.ROUTER_CACHE_SIZE)
route = self.routes[url][0] def _get(self, url, method):
"""
Gets a request handler based on the URL of the request, or raises an
error. Internal method for caching.
:param url: Request URL
:param method: Request method
:return: handler, arguments, keyword arguments
"""
# Check against known static routes
route = self.routes_static.get(url)
if route:
match = route.pattern.match(url) match = route.pattern.match(url)
else: else:
for route in self.routes[url_hash(url)]: # Move on to testing all regex routes
for route in self.routes_dynamic[url_hash(url)]:
match = route.pattern.match(url) match = route.pattern.match(url)
if match: if match:
break break
else: else:
raise NotFound('Requested URL {} not found'.format(url)) # Lastly, check against all regex routes that cannot be hashed
for route in self.routes_always_check:
match = route.pattern.match(url)
if match:
break
else:
raise NotFound('Requested URL {} not found'.format(url))
if route.methods and request.method not in route.methods: if route.methods and method not in route.methods:
raise InvalidUsage( raise InvalidUsage(
'Method {} not allowed for URL {}'.format( 'Method {} not allowed for URL {}'.format(
request.method, url), status_code=405) method, url), status_code=405)
kwargs = {p.name: p.cast(value) kwargs = {p.name: p.cast(value)
for value, p for value, p

View File

@ -1,6 +1,8 @@
from json import loads as json_loads, dumps as json_dumps import pytest
from sanic import Sanic from sanic import Sanic
from sanic.response import json, text from sanic.response import text
from sanic.router import RouteExists
from sanic.utils import sanic_endpoint_test from sanic.utils import sanic_endpoint_test
@ -8,6 +10,24 @@ from sanic.utils import sanic_endpoint_test
# UTF-8 # UTF-8
# ------------------------------------------------------------ # # ------------------------------------------------------------ #
def test_static_routes():
app = Sanic('test_dynamic_route')
@app.route('/test')
async def handler1(request):
return text('OK1')
@app.route('/pizazz')
async def handler2(request):
return text('OK2')
request, response = sanic_endpoint_test(app, uri='/test')
assert response.text == 'OK1'
request, response = sanic_endpoint_test(app, uri='/pizazz')
assert response.text == 'OK2'
def test_dynamic_route(): def test_dynamic_route():
app = Sanic('test_dynamic_route') app = Sanic('test_dynamic_route')
@ -102,3 +122,45 @@ def test_dynamic_route_regex():
request, response = sanic_endpoint_test(app, uri='/folder/') request, response = sanic_endpoint_test(app, uri='/folder/')
assert response.status == 200 assert response.status == 200
def test_dynamic_route_unhashable():
app = Sanic('test_dynamic_route_unhashable')
@app.route('/folder/<unhashable:[A-Za-z0-9/]+>/end/')
async def handler(request, unhashable):
return text('OK')
request, response = sanic_endpoint_test(app, uri='/folder/test/asdf/end/')
assert response.status == 200
request, response = sanic_endpoint_test(app, uri='/folder/test///////end/')
assert response.status == 200
request, response = sanic_endpoint_test(app, uri='/folder/test/end/')
assert response.status == 200
request, response = sanic_endpoint_test(app, uri='/folder/test/nope/')
assert response.status == 404
def test_route_duplicate():
app = Sanic('test_dynamic_route')
with pytest.raises(RouteExists):
@app.route('/test')
async def handler1(request):
pass
@app.route('/test')
async def handler2(request):
pass
with pytest.raises(RouteExists):
@app.route('/test/<dynamic>/')
async def handler1(request, dynamic):
pass
@app.route('/test/<dynamic>/')
async def handler2(request, dynamic):
pass