Merge pull request #1054 from r0fls/rfc7231
fix issues with method not allowed response
This commit is contained in:
commit
72254a7af9
|
@ -544,6 +544,7 @@ class Sanic:
|
|||
|
||||
# Fetch handler from router
|
||||
handler, args, kwargs, uri = self.router.get(request)
|
||||
|
||||
request.uri_template = uri
|
||||
if handler is None:
|
||||
raise ServerError(
|
||||
|
|
|
@ -150,6 +150,16 @@ class InvalidUsage(SanicException):
|
|||
pass
|
||||
|
||||
|
||||
@add_status_code(405)
|
||||
class MethodNotSupported(SanicException):
|
||||
def __init__(self, message, method, allowed_methods):
|
||||
super().__init__(message)
|
||||
self.headers = dict()
|
||||
self.headers["Allow"] = ", ".join(allowed_methods)
|
||||
if method in ['HEAD', 'PATCH', 'PUT', 'DELETE']:
|
||||
self.headers['Content-Length'] = 0
|
||||
|
||||
|
||||
@add_status_code(500)
|
||||
class ServerError(SanicException):
|
||||
pass
|
||||
|
@ -167,8 +177,6 @@ class URLBuildError(ServerError):
|
|||
|
||||
|
||||
class FileNotFound(NotFound):
|
||||
pass
|
||||
|
||||
def __init__(self, message, path, relative_url):
|
||||
super().__init__(message)
|
||||
self.path = path
|
||||
|
@ -198,8 +206,6 @@ class HeaderNotFound(InvalidUsage):
|
|||
|
||||
@add_status_code(416)
|
||||
class ContentRangeError(SanicException):
|
||||
pass
|
||||
|
||||
def __init__(self, message, content_range):
|
||||
super().__init__(message)
|
||||
self.headers = {
|
||||
|
|
|
@ -3,7 +3,7 @@ from collections import defaultdict, namedtuple
|
|||
from collections.abc import Iterable
|
||||
from functools import lru_cache
|
||||
|
||||
from sanic.exceptions import NotFound, InvalidUsage
|
||||
from sanic.exceptions import NotFound, MethodNotSupported
|
||||
from sanic.views import CompositionView
|
||||
|
||||
Route = namedtuple(
|
||||
|
@ -350,6 +350,16 @@ class Router:
|
|||
except NotFound:
|
||||
return self._get(request.path, request.method, '')
|
||||
|
||||
def get_supported_methods(self, url):
|
||||
"""Get a list of supported methods for a url and optional host.
|
||||
|
||||
:param url: URL string (including host)
|
||||
:return: frozenset of supported methods
|
||||
"""
|
||||
route = self.routes_all.get(url)
|
||||
# if methods are None then this logic will prevent an error
|
||||
return getattr(route, 'methods', None) or frozenset()
|
||||
|
||||
@lru_cache(maxsize=ROUTER_CACHE_SIZE)
|
||||
def _get(self, url, method, host):
|
||||
"""Get a request handler based on the URL of the request, or raises an
|
||||
|
@ -362,9 +372,10 @@ class Router:
|
|||
url = host + url
|
||||
# Check against known static routes
|
||||
route = self.routes_static.get(url)
|
||||
method_not_supported = InvalidUsage(
|
||||
'Method {} not allowed for URL {}'.format(
|
||||
method, url), status_code=405)
|
||||
method_not_supported = MethodNotSupported(
|
||||
'Method {} not allowed for URL {}'.format(method, url),
|
||||
method=method,
|
||||
allowed_methods=self.get_supported_methods(url))
|
||||
if route:
|
||||
if route.methods and method not in route.methods:
|
||||
raise method_not_supported
|
||||
|
@ -407,7 +418,7 @@ class Router:
|
|||
"""
|
||||
try:
|
||||
handler = self.get(request)[0]
|
||||
except (NotFound, InvalidUsage):
|
||||
except (NotFound, MethodNotSupported):
|
||||
return False
|
||||
if (hasattr(handler, 'view_class') and
|
||||
hasattr(handler.view_class, request.method.lower())):
|
||||
|
|
|
@ -35,6 +35,25 @@ async def sample_streaming_fn(response):
|
|||
await asyncio.sleep(.001)
|
||||
response.write('bar')
|
||||
|
||||
def test_method_not_allowed():
|
||||
app = Sanic('method_not_allowed')
|
||||
|
||||
@app.get('/')
|
||||
async def test(request):
|
||||
return response.json({'hello': 'world'})
|
||||
|
||||
request, response = app.test_client.head('/')
|
||||
assert response.headers['Allow']== 'GET'
|
||||
|
||||
@app.post('/')
|
||||
async def test(request):
|
||||
return response.json({'hello': 'world'})
|
||||
|
||||
request, response = app.test_client.head('/')
|
||||
assert response.status == 405
|
||||
assert set(response.headers['Allow'].split(', ')) == set(['GET', 'POST'])
|
||||
assert response.headers['Content-Length'] == '0'
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def json_app():
|
||||
|
|
Loading…
Reference in New Issue
Block a user