diff --git a/examples/cors.html b/examples/cors.html new file mode 100644 index 00000000..642a102e --- /dev/null +++ b/examples/cors.html @@ -0,0 +1,50 @@ + + + + + Cors Example + + + + + + + \ No newline at end of file diff --git a/examples/cors_example.py b/examples/cors_example.py new file mode 100644 index 00000000..c60c7b6d --- /dev/null +++ b/examples/cors_example.py @@ -0,0 +1,31 @@ +from sanic import Sanic +from sanic import response + +app = Sanic(__name__) + + +@app.route("/") +@response.cors() +async def test(request): + return response.json({"test": True}) + + +@app.route("/t2") +@response.cors() +def test2(request): + return response.json({"test": True}) + + +@app.websocket('/feed') +@response.cors() +async def feed(request, ws): + while True: + data = 'hello!' + print('Sending: ' + data) + await ws.send(data) + data = await ws.recv() + print('Received: ' + data) + + +if __name__ == '__main__': + app.run(host="0.0.0.0", port=8000) diff --git a/sanic/response.py b/sanic/response.py index 582e11cf..002878e5 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -1,5 +1,7 @@ from mimetypes import guess_type from os import path +from functools import wraps +from inspect import isawaitable try: from ujson import dumps as json_dumps @@ -9,6 +11,7 @@ except: from aiofiles import open as open_async from sanic.cookies import CookieJar +from sanic.constants import HTTP_METHODS COMMON_STATUS_CODES = { 200: b'OK', @@ -427,3 +430,45 @@ def redirect(to, headers=None, status=302, status=status, headers=headers, content_type=content_type) + + +def cors(origin=None, allow_methods=None): + if isinstance(allow_methods, (list, tuple, set)): + allow_methods = ', '.join( + filter( + lambda x: x in HTTP_METHODS, + map( + lambda x: x.upper(), + allow_methods + ) + ) + ) + elif allow_methods: + raise ValueError('allow_methods must be instance of list, tuple or set.') + + cors_headers = { + 'Access-Control-Allow-Credentials': 'true', + 'Access-Control-Allow-Methods': allow_methods or 'GET', + 'Access-Control-Allow-Origin': origin or '*' + } + + def decorator(fn): + @wraps(fn) + def wrap(*args, **kwargs): + res = fn(*args, **kwargs) + if isinstance(res, BaseHTTPResponse): + res.headers.update(cors_headers) + return res + elif isawaitable(res): + async def make_cors(): + response = await res + response.headers.update(cors_headers) + return response + + return make_cors() + + return res + + return wrap + + return decorator