Merge pull request #1054 from r0fls/rfc7231

fix issues with method not allowed response
This commit is contained in:
Raphael Deem 2017-12-13 23:42:37 -08:00 committed by GitHub
commit 72254a7af9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 47 additions and 10 deletions

View File

@ -544,6 +544,7 @@ class Sanic:
# Fetch handler from router # Fetch handler from router
handler, args, kwargs, uri = self.router.get(request) handler, args, kwargs, uri = self.router.get(request)
request.uri_template = uri request.uri_template = uri
if handler is None: if handler is None:
raise ServerError( raise ServerError(

View File

@ -150,6 +150,16 @@ class InvalidUsage(SanicException):
pass pass
@add_status_code(405)
class MethodNotSupported(SanicException):
def __init__(self, message, method, allowed_methods):
super().__init__(message)
self.headers = dict()
self.headers["Allow"] = ", ".join(allowed_methods)
if method in ['HEAD', 'PATCH', 'PUT', 'DELETE']:
self.headers['Content-Length'] = 0
@add_status_code(500) @add_status_code(500)
class ServerError(SanicException): class ServerError(SanicException):
pass pass
@ -167,8 +177,6 @@ class URLBuildError(ServerError):
class FileNotFound(NotFound): class FileNotFound(NotFound):
pass
def __init__(self, message, path, relative_url): def __init__(self, message, path, relative_url):
super().__init__(message) super().__init__(message)
self.path = path self.path = path
@ -198,8 +206,6 @@ class HeaderNotFound(InvalidUsage):
@add_status_code(416) @add_status_code(416)
class ContentRangeError(SanicException): class ContentRangeError(SanicException):
pass
def __init__(self, message, content_range): def __init__(self, message, content_range):
super().__init__(message) super().__init__(message)
self.headers = { self.headers = {

View File

@ -3,7 +3,7 @@ from collections import defaultdict, namedtuple
from collections.abc import Iterable from collections.abc import Iterable
from functools import lru_cache from functools import lru_cache
from sanic.exceptions import NotFound, InvalidUsage from sanic.exceptions import NotFound, MethodNotSupported
from sanic.views import CompositionView from sanic.views import CompositionView
Route = namedtuple( Route = namedtuple(
@ -350,6 +350,16 @@ class Router:
except NotFound: except NotFound:
return self._get(request.path, request.method, '') return self._get(request.path, request.method, '')
def get_supported_methods(self, url):
"""Get a list of supported methods for a url and optional host.
:param url: URL string (including host)
:return: frozenset of supported methods
"""
route = self.routes_all.get(url)
# if methods are None then this logic will prevent an error
return getattr(route, 'methods', None) or frozenset()
@lru_cache(maxsize=ROUTER_CACHE_SIZE) @lru_cache(maxsize=ROUTER_CACHE_SIZE)
def _get(self, url, method, host): def _get(self, url, method, host):
"""Get a request handler based on the URL of the request, or raises an """Get a request handler based on the URL of the request, or raises an
@ -362,9 +372,10 @@ class Router:
url = host + url url = host + url
# Check against known static routes # Check against known static routes
route = self.routes_static.get(url) route = self.routes_static.get(url)
method_not_supported = InvalidUsage( method_not_supported = MethodNotSupported(
'Method {} not allowed for URL {}'.format( 'Method {} not allowed for URL {}'.format(method, url),
method, url), status_code=405) method=method,
allowed_methods=self.get_supported_methods(url))
if route: if route:
if route.methods and method not in route.methods: if route.methods and method not in route.methods:
raise method_not_supported raise method_not_supported
@ -407,7 +418,7 @@ class Router:
""" """
try: try:
handler = self.get(request)[0] handler = self.get(request)[0]
except (NotFound, InvalidUsage): except (NotFound, MethodNotSupported):
return False return False
if (hasattr(handler, 'view_class') and if (hasattr(handler, 'view_class') and
hasattr(handler.view_class, request.method.lower())): hasattr(handler.view_class, request.method.lower())):

View File

@ -35,6 +35,25 @@ async def sample_streaming_fn(response):
await asyncio.sleep(.001) await asyncio.sleep(.001)
response.write('bar') response.write('bar')
def test_method_not_allowed():
app = Sanic('method_not_allowed')
@app.get('/')
async def test(request):
return response.json({'hello': 'world'})
request, response = app.test_client.head('/')
assert response.headers['Allow']== 'GET'
@app.post('/')
async def test(request):
return response.json({'hello': 'world'})
request, response = app.test_client.head('/')
assert response.status == 405
assert set(response.headers['Allow'].split(', ')) == set(['GET', 'POST'])
assert response.headers['Content-Length'] == '0'
@pytest.fixture @pytest.fixture
def json_app(): def json_app():
@ -254,4 +273,4 @@ def test_file_stream_head_response(file_name, static_file_directory):
assert 'Content-Length' in response.headers assert 'Content-Length' in response.headers
assert int(response.headers[ assert int(response.headers[
'Content-Length']) == len( 'Content-Length']) == len(
get_file_content(static_file_directory, file_name)) get_file_content(static_file_directory, file_name))