add support for virtual hosts

This commit is contained in:
Raphael Deem 2017-01-08 15:48:12 -08:00
parent b0bf989056
commit 4f832ac9af
4 changed files with 71 additions and 10 deletions

18
examples/vhosts.py Normal file
View File

@ -0,0 +1,18 @@
from sanic.response import text
from sanic import Sanic
# Usage
# curl -H "Host: example.com" localhost:8000
# curl -H "Host: sub.example.com" localhost:8000
app = Sanic()
@app.route('/', host="example.com")
async def hello(request):
return text("Answer")
@app.route('/', host="sub.example.com")
async def hello(request):
return text("42")
if __name__ == '__main__':
app.run(host="0.0.0.0", port=8000)

View File

@ -55,8 +55,9 @@ class Router:
self.routes_static = {}
self.routes_dynamic = defaultdict(list)
self.routes_always_check = []
self.hosts = None
def add(self, uri, methods, handler):
def add(self, uri, methods, handler, host=None):
"""
Adds a handler to the route list
:param uri: Path to match
@ -66,6 +67,17 @@ class Router:
When executed, it should provide a response object.
:return: Nothing
"""
if host is not None:
# we want to track if there are any
# vhosts on the Router instance so that we can
# default to the behavior without vhosts
if self.hosts is None:
self.hosts = set(host)
else:
self.hosts.add(host)
uri = host + uri
if uri in self.routes_all:
raise RouteExists("Route already registered: {}".format(uri))
@ -113,7 +125,9 @@ class Router:
else:
self.routes_static[uri] = route
def remove(self, uri, clean_cache=True):
def remove(self, uri, clean_cache=True, host=None):
if host is not None:
uri = host + uri
try:
route = self.routes_all.pop(uri)
except KeyError:
@ -137,10 +151,14 @@ class Router:
:param request: Request object
:return: handler, arguments, keyword arguments
"""
return self._get(request.url, request.method)
if self.hosts is None:
return self._get(request.url, request.method, '')
else:
return self._get(request.url, request.method,
request.headers.get("Host", ''))
@lru_cache(maxsize=Config.ROUTER_CACHE_SIZE)
def _get(self, url, method):
def _get(self, url, method, host=None):
"""
Gets a request handler based on the URL of the request, or raises an
error. Internal method for caching.
@ -148,6 +166,7 @@ class Router:
:param method: Request method
:return: handler, arguments, keyword arguments
"""
url = host + url
# Check against known static routes
route = self.routes_static.get(url)
if route:

View File

@ -51,7 +51,7 @@ class Sanic:
# -------------------------------------------------------------------- #
# Decorator
def route(self, uri, methods=None):
def route(self, uri, methods=None, host=None):
"""
Decorates a function to be registered as a route
:param uri: path of the URL
@ -65,12 +65,13 @@ class Sanic:
uri = '/' + uri
def response(handler):
self.router.add(uri=uri, methods=methods, handler=handler)
self.router.add(uri=uri, methods=methods, handler=handler,
host=host)
return handler
return response
def add_route(self, handler, uri, methods=None):
def add_route(self, handler, uri, methods=None, host=None):
"""
A helper method to register class instance or
functions as a handler to the application url
@ -80,11 +81,11 @@ class Sanic:
:param methods: list or tuple of methods allowed
:return: function or class instance
"""
self.route(uri=uri, methods=methods)(handler)
self.route(uri=uri, methods=methods, host=host)(handler)
return handler
def remove_route(self, uri, clean_cache=True):
self.router.remove(uri, clean_cache)
def remove_route(self, uri, clean_cache=True, host=None):
self.router.remove(uri, clean_cache, host)
# Decorator
def exception(self, *exceptions):

23
tests/test_vhosts.py Normal file
View File

@ -0,0 +1,23 @@
from sanic import Sanic
from sanic.response import json, text
from sanic.utils import sanic_endpoint_test
def test_vhosts():
app = Sanic('test_text')
@app.route('/', host="example.com")
async def handler(request):
return text("You're at example.com!")
@app.route('/', host="subdomain.example.com")
async def handler(request):
return text("You're at subdomain.example.com!")
headers = {"Host": "example.com"}
request, response = sanic_endpoint_test(app, headers=headers)
assert response.text == "You're at example.com!"
headers = {"Host": "subdomain.example.com"}
request, response = sanic_endpoint_test(app, headers=headers)
assert response.text == "You're at subdomain.example.com!"