diff --git a/sanic/app.py b/sanic/app.py index 91179af6..32eddb25 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -543,6 +543,7 @@ class Sanic: # Fetch handler from router handler, args, kwargs, uri = self.router.get(request) + request.uri_template = uri if handler is None: raise ServerError( diff --git a/sanic/exceptions.py b/sanic/exceptions.py index aa1e0d4d..5bf5f2de 100644 --- a/sanic/exceptions.py +++ b/sanic/exceptions.py @@ -150,6 +150,16 @@ class InvalidUsage(SanicException): pass +@add_status_code(405) +class MethodNotSupported(SanicException): + def __init__(self, message, method, allowed_methods): + super().__init__(message) + self.headers = dict() + self.headers["Allow"] = ", ".join(allowed_methods) + if method in ['HEAD', 'PATCH', 'PUT', 'DELETE']: + self.headers['Content-Length'] = 0 + + @add_status_code(500) class ServerError(SanicException): pass diff --git a/sanic/router.py b/sanic/router.py index b601622c..4f5470c0 100644 --- a/sanic/router.py +++ b/sanic/router.py @@ -3,7 +3,7 @@ from collections import defaultdict, namedtuple from collections.abc import Iterable from functools import lru_cache -from sanic.exceptions import NotFound, InvalidUsage +from sanic.exceptions import NotFound, InvalidUsage, MethodNotSupported from sanic.views import CompositionView Route = namedtuple( @@ -352,6 +352,15 @@ class Router: except NotFound: return self._get(request.path, request.method, '') + def get_supported_methods(self, url): + """Get a list of supported methods for a url and optional host. + + :param url: URL string (including host) + :return: frozenset of supported methods + """ + route = self.routes_all.get(url) + return getattr(route, 'methods', frozenset()) + @lru_cache(maxsize=ROUTER_CACHE_SIZE) def _get(self, url, method, host): """Get a request handler based on the URL of the request, or raises an @@ -364,9 +373,10 @@ class Router: url = host + url # Check against known static routes route = self.routes_static.get(url) - method_not_supported = InvalidUsage( - 'Method {} not allowed for URL {}'.format( - method, url), status_code=405) + method_not_supported = MethodNotSupported( + 'Method {} not allowed for URL {}'.format(method, url), + method=method, + allowed_methods=self.get_supported_methods(url)) if route: if route.methods and method not in route.methods: raise method_not_supported diff --git a/tests/test_response.py b/tests/test_response.py index 910c4e80..086b4e58 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -35,6 +35,25 @@ async def sample_streaming_fn(response): await asyncio.sleep(.001) response.write('bar') +def test_method_not_allowed(): + app = Sanic('method_not_allowed') + + @app.get('/') + async def test(request): + return response.json({'hello': 'world'}) + + request, response = app.test_client.head('/') + assert response.headers['Allow']== 'GET' + + @app.post('/') + async def test(request): + return response.json({'hello': 'world'}) + + request, response = app.test_client.head('/') + assert response.status == 405 + assert set(response.headers['Allow'].split(', ')) == set(['GET', 'POST']) + assert response.headers['Content-Length'] == '0' + @pytest.fixture def json_app(): @@ -254,4 +273,4 @@ def test_file_stream_head_response(file_name, static_file_directory): assert 'Content-Length' in response.headers assert int(response.headers[ 'Content-Length']) == len( - get_file_content(static_file_directory, file_name)) \ No newline at end of file + get_file_content(static_file_directory, file_name))