diff --git a/sanic/app.py b/sanic/app.py index 05d99c08..23d404f1 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -544,6 +544,7 @@ class Sanic: # Fetch handler from router handler, args, kwargs, uri = self.router.get(request) + request.uri_template = uri if handler is None: raise ServerError( diff --git a/sanic/exceptions.py b/sanic/exceptions.py index aa1e0d4d..0133fd64 100644 --- a/sanic/exceptions.py +++ b/sanic/exceptions.py @@ -150,6 +150,16 @@ class InvalidUsage(SanicException): pass +@add_status_code(405) +class MethodNotSupported(SanicException): + def __init__(self, message, method, allowed_methods): + super().__init__(message) + self.headers = dict() + self.headers["Allow"] = ", ".join(allowed_methods) + if method in ['HEAD', 'PATCH', 'PUT', 'DELETE']: + self.headers['Content-Length'] = 0 + + @add_status_code(500) class ServerError(SanicException): pass @@ -167,8 +177,6 @@ class URLBuildError(ServerError): class FileNotFound(NotFound): - pass - def __init__(self, message, path, relative_url): super().__init__(message) self.path = path @@ -198,8 +206,6 @@ class HeaderNotFound(InvalidUsage): @add_status_code(416) class ContentRangeError(SanicException): - pass - def __init__(self, message, content_range): super().__init__(message) self.headers = { diff --git a/sanic/router.py b/sanic/router.py index 208e3772..2383b915 100644 --- a/sanic/router.py +++ b/sanic/router.py @@ -3,7 +3,7 @@ from collections import defaultdict, namedtuple from collections.abc import Iterable from functools import lru_cache -from sanic.exceptions import NotFound, InvalidUsage +from sanic.exceptions import NotFound, MethodNotSupported from sanic.views import CompositionView Route = namedtuple( @@ -350,6 +350,16 @@ class Router: except NotFound: return self._get(request.path, request.method, '') + def get_supported_methods(self, url): + """Get a list of supported methods for a url and optional host. + + :param url: URL string (including host) + :return: frozenset of supported methods + """ + route = self.routes_all.get(url) + # if methods are None then this logic will prevent an error + return getattr(route, 'methods', None) or frozenset() + @lru_cache(maxsize=ROUTER_CACHE_SIZE) def _get(self, url, method, host): """Get a request handler based on the URL of the request, or raises an @@ -362,9 +372,10 @@ class Router: url = host + url # Check against known static routes route = self.routes_static.get(url) - method_not_supported = InvalidUsage( - 'Method {} not allowed for URL {}'.format( - method, url), status_code=405) + method_not_supported = MethodNotSupported( + 'Method {} not allowed for URL {}'.format(method, url), + method=method, + allowed_methods=self.get_supported_methods(url)) if route: if route.methods and method not in route.methods: raise method_not_supported @@ -407,7 +418,7 @@ class Router: """ try: handler = self.get(request)[0] - except (NotFound, InvalidUsage): + except (NotFound, MethodNotSupported): return False if (hasattr(handler, 'view_class') and hasattr(handler.view_class, request.method.lower())): diff --git a/tests/test_response.py b/tests/test_response.py index 910c4e80..086b4e58 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -35,6 +35,25 @@ async def sample_streaming_fn(response): await asyncio.sleep(.001) response.write('bar') +def test_method_not_allowed(): + app = Sanic('method_not_allowed') + + @app.get('/') + async def test(request): + return response.json({'hello': 'world'}) + + request, response = app.test_client.head('/') + assert response.headers['Allow']== 'GET' + + @app.post('/') + async def test(request): + return response.json({'hello': 'world'}) + + request, response = app.test_client.head('/') + assert response.status == 405 + assert set(response.headers['Allow'].split(', ')) == set(['GET', 'POST']) + assert response.headers['Content-Length'] == '0' + @pytest.fixture def json_app(): @@ -254,4 +273,4 @@ def test_file_stream_head_response(file_name, static_file_directory): assert 'Content-Length' in response.headers assert int(response.headers[ 'Content-Length']) == len( - get_file_content(static_file_directory, file_name)) \ No newline at end of file + get_file_content(static_file_directory, file_name))