P G Jones

How to write ASGI middleware

2021-02-07

Recently a Hypercorn user asked how to serve multiple domains, with a different ASGI application serving each domain. This can be solved with middleware, and the solution is similar to the existing Dispatcher Middleware, which dispatches based on the request path.

Middleware, in the context of ASGI, wraps an ASGI application to add to or alter its functionality. As it wraps the application it runs first, and it can choose whether to alter the application's messages or even completely ignore them.

Basic middleware

ASGI middleware must be an ASGI callable, i.e. implement this interface,

async def app(
    scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
) -> None:
    ...

see Hypercorn's typing for the types used.

To add additional functionality the middleware is usually initialised with an ASGI application, although it can be initialised however you like. For our basic example we'll wrap a single app, either functionally,

def middleware(app: ASGI3Framework) -> ASGI3Framework:
    # A plain (non-async) factory that returns the wrapping ASGI callable.
    async def wrapped(
        scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
    ) -> None:
        await app(scope, receive, send)

    return wrapped

or via an object,

class BasicMiddleware:
    def __init__(self, app: ASGI3Framework) -> None:
        self._app = app

    async def __call__(
        self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
    ) -> None:
        await self._app(scope, receive, send)
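
As a sketch of how a middleware can alter the application's messages rather than only forwarding them, the example below wraps send to append a header to each HTTP response. The names AddHeaderMiddleware, hello_app, and call_once are hypothetical and exist only for illustration, and plain typing aliases stand in for Hypercorn's types so that the block is self-contained.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List

Message = Dict[str, Any]
Receive = Callable[[], Awaitable[Message]]
Send = Callable[[Message], Awaitable[None]]
App = Callable[[Message, Receive, Send], Awaitable[None]]


class AddHeaderMiddleware:
    """Hypothetical middleware that appends one header to HTTP responses."""

    def __init__(self, app: App, name: bytes, value: bytes) -> None:
        self._app = app
        self._name = name
        self._value = value

    async def __call__(self, scope: Message, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            # Pass lifespan and websocket scopes through untouched.
            await self._app(scope, receive, send)
            return

        async def send_with_header(message: Message) -> None:
            if message["type"] == "http.response.start":
                headers = list(message.get("headers", []))
                headers.append((self._name, self._value))
                # Copy the message rather than mutating the app's dict.
                message = {**message, "headers": headers}
            await send(message)

        await self._app(scope, receive, send_with_header)


async def hello_app(scope: Message, receive: Receive, send: Send) -> None:
    """Hypothetical wrapped app that answers every request with 'hello'."""
    await send(
        {
            "type": "http.response.start",
            "status": 200,
            "headers": [(b"content-length", b"5")],
        }
    )
    await send({"type": "http.response.body", "body": b"hello"})


async def call_once(app: App) -> List[Message]:
    """Invoke an app with a hand-built http scope and collect what it sends."""
    sent: List[Message] = []

    async def receive() -> Message:
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message: Message) -> None:
        sent.append(message)

    await app({"type": "http", "headers": []}, receive, send)
    return sent
```

Calling asyncio.run(call_once(AddHeaderMiddleware(hello_app, b"x-served-by", b"middleware"))) returns the two response messages, with the extra header appended to the first.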

Domain dispatching

To solve the user's problem and dispatch based on the request's target domain we can make use of the fact that ASGI servers normalise any :authority pseudo-header (HTTP/2 & HTTP/3) to a host header (HTTP/1.1), which means the middleware can match on the host header to dispatch.
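
One detail worth keeping in mind when matching on the host header is that its value may include a port (for example example.com:8000) and that domain names are case-insensitive. The following is a minimal sketch of a lookup that accounts for both; normalise_host and get_host are hypothetical helper names, not part of Hypercorn.

```python
from typing import Iterable, Optional, Tuple


def normalise_host(value: bytes) -> str:
    """Lower-case a host header value and drop any trailing port."""
    host = value.decode("latin-1").strip().lower()
    if host.startswith("["):
        # Bracketed IPv6 literal such as "[::1]:8000": keep up to the bracket.
        end = host.find("]")
        return host[: end + 1] if end != -1 else host
    # "example.com:8000" becomes "example.com"; a bare name is unchanged.
    return host.split(":", 1)[0]


def get_host(headers: Iterable[Tuple[bytes, bytes]]) -> Optional[str]:
    """Return the normalised host from ASGI headers, or None if absent."""
    for name, value in headers:
        if name.lower() == b"host":
            return normalise_host(value)
    return None
```

With this in place the keys of the mounts dictionary would be written in lower case and without a port.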

First, let's initialise the middleware with a dictionary of domain names mapped to ASGI applications,

class DomainDispatcherMiddleware:
    def __init__(self, mounts: Dict[str, ASGIFramework]) -> None:
        self.mounts = mounts

then we can dispatch based on the host header match, or 404 if there is no match,

class DomainDispatcherMiddleware:
    ...
    async def __call__(
        self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
    ) -> None:
        if scope["type"] in {"http", "websocket"}:
            host = self._get_host(scope["headers"])
            app = self.mounts.get(host)
            if app is not None:
                await app(scope, receive, send)
            else:
                await send(
                    {
                        "type": "http.response.start",
                        "status": 404,
                        "headers": [(b"content-length", b"0")],
                    }
                )
                await send({"type": "http.response.body"})

    def _get_host(self, headers: Iterable[Tuple[bytes, bytes]]) -> Optional[str]:
        # Iterate over the headers argument; scope is not available here.
        for name, value in headers:
            if name.lower() == b"host":
                return value.decode()
        return None
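
The 404 above uses http.response.* messages, which the ASGI specification defines for the http scope only. For a websocket scope the handshake is refused by sending websocket.close before accepting, which the specification says the server turns into a 403 response. A possible sketch of a rejection that picks messages by scope type follows; reject and collect are hypothetical names, and the article's middleware does not include this step.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List

Message = Dict[str, Any]
Receive = Callable[[], Awaitable[Message]]
Send = Callable[[Message], Awaitable[None]]


async def reject(scope: Message, receive: Receive, send: Send) -> None:
    """Refuse an unmatched connection with messages valid for its scope."""
    if scope["type"] == "http":
        await send(
            {
                "type": "http.response.start",
                "status": 404,
                "headers": [(b"content-length", b"0")],
            }
        )
        await send({"type": "http.response.body", "body": b""})
    elif scope["type"] == "websocket":
        # Consume the websocket.connect event, then close without accepting.
        await receive()
        await send({"type": "websocket.close"})


async def collect(scope_type: str) -> List[str]:
    """Run reject against fake callables and return the message types sent."""
    sent: List[str] = []

    async def receive() -> Message:
        return {"type": "websocket.connect"}

    async def send(message: Message) -> None:
        sent.append(message["type"])

    await reject({"type": scope_type}, receive, send)
    return sent
```

In the dispatcher, the else branch could then call await reject(scope, receive, send) instead of sending the 404 directly.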

Managing the lifespan

The domain dispatching middleware above forwards http and websocket scopes but ignores the lifespan scope. This may result in the wrapped ASGI applications not initialising correctly. Therefore it is important that middleware manages the lifespan.

For middleware that wraps a single app this is easy, as the lifespan scope and all messages can be passed directly to the app. However, for the domain dispatching middleware and any other middleware that manages multiple apps, care must be taken to ensure all apps receive the lifespan scope and messages.

Managing the lifespan scope for a multi-app middleware requires invoking all the apps as separate tasks, and then only reporting completion when all the apps have reported complete. This can be achieved as follows,

class MultiAppMiddleware:
    ...
    async def __call__(
        self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
    ) -> None:
        if scope["type"] == "lifespan":
            self.app_queues: Dict[str, asyncio.Queue] = {
                domain: asyncio.Queue(MAX_QUEUE_SIZE) for domain in self.mounts
            }
            self.startup_complete = {domain: False for domain in self.mounts}
            self.shutdown_complete = {domain: False for domain in self.mounts}

            for domain, app in self.mounts.items():
                asyncio.ensure_future(
                    app(
                        scope,
                        self.app_queues[domain].get,
                        partial(self.send, domain, send),
                    )
                )

            while True:
                message = await receive()
                for queue in self.app_queues.values():
                    await queue.put(message)
                if message["type"] == "lifespan.shutdown":
                    break
        elif scope["type"] in {"http", "websocket"}:
            host = self._get_host(scope["headers"])
            app = self.mounts.get(host)
            if app is not None:
                await app(scope, receive, send)
            else:
                await send(
                    {
                        "type": "http.response.start",
                        "status": 404,
                        "headers": [(b"content-length", b"0")],
                    }
                )
                await send({"type": "http.response.body"})

    def _get_host(self, headers: Iterable[Tuple[bytes, bytes]]) -> Optional[str]:
        # Iterate over the headers argument; scope is not available here.
        for name, value in headers:
            if name.lower() == b"host":
                return value.decode()
        return None

    async def send(self, domain: str, send: Callable, message: dict) -> None:
        if message["type"] == "lifespan.startup.complete":
            self.startup_complete[domain] = True
            if all(self.startup_complete.values()):
                await send({"type": "lifespan.startup.complete"})
        elif message["type"] == "lifespan.shutdown.complete":
            self.shutdown_complete[domain] = True
            if all(self.shutdown_complete.values()):
                await send({"type": "lifespan.shutdown.complete"})
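
To see the fan-out behaviour without a server, the lifespan branch can be driven by hand. The sketch below is a self-contained stand-in, not the article's class: LifespanFanOut, lifespan_app, and run_demo are hypothetical names, the queues are unbounded for simplicity, and, unlike the code above, it awaits the app tasks at the end so that the demo finishes deterministically.

```python
import asyncio
from functools import partial
from typing import Any, Awaitable, Callable, Dict, List, Tuple

Message = Dict[str, Any]
Send = Callable[[Message], Awaitable[None]]


async def lifespan_app(name: str, log: List[str], scope, receive, send) -> None:
    """Hypothetical app that only speaks the lifespan protocol."""
    while True:
        message = await receive()
        if message["type"] == "lifespan.startup":
            log.append(f"{name} started")
            await send({"type": "lifespan.startup.complete"})
        elif message["type"] == "lifespan.shutdown":
            log.append(f"{name} stopped")
            await send({"type": "lifespan.shutdown.complete"})
            return


class LifespanFanOut:
    """Stand-in for the lifespan branch of the multi-app middleware."""

    def __init__(self, mounts: Dict[str, Callable[..., Awaitable[None]]]) -> None:
        self.mounts = mounts

    async def __call__(self, scope, receive, send) -> None:
        queues: Dict[str, asyncio.Queue] = {d: asyncio.Queue() for d in self.mounts}
        self.startup_complete = {d: False for d in self.mounts}
        self.shutdown_complete = {d: False for d in self.mounts}
        tasks = [
            asyncio.ensure_future(app(scope, queues[d].get, partial(self.send, d, send)))
            for d, app in self.mounts.items()
        ]
        while True:
            message = await receive()
            for queue in queues.values():
                await queue.put(message)  # every app sees every server message
            if message["type"] == "lifespan.shutdown":
                break
        # Demo-only addition: wait for the apps to finish shutting down.
        await asyncio.gather(*tasks)

    async def send(self, domain: str, send: Send, message: Message) -> None:
        # Report to the server only once every app has reported.
        if message["type"] == "lifespan.startup.complete":
            self.startup_complete[domain] = True
            if all(self.startup_complete.values()):
                await send(message)
        elif message["type"] == "lifespan.shutdown.complete":
            self.shutdown_complete[domain] = True
            if all(self.shutdown_complete.values()):
                await send(message)


async def run_demo() -> Tuple[List[str], List[str]]:
    log: List[str] = []
    sent: List[str] = []
    started = asyncio.Event()
    events = iter(["lifespan.startup", "lifespan.shutdown"])

    async def receive() -> Message:
        event = next(events)
        if event == "lifespan.shutdown":
            await started.wait()  # a server shuts down only after startup
        return {"type": event}

    async def send(message: Message) -> None:
        sent.append(message["type"])
        if message["type"] == "lifespan.startup.complete":
            started.set()

    mounts = {
        "example.com": partial(lifespan_app, "example", log),
        "another.com": partial(lifespan_app, "another", log),
    }
    await LifespanFanOut(mounts)({"type": "lifespan"}, receive, send)
    return log, sent
```

Running asyncio.run(run_demo()) shows both apps starting and stopping, while the fake server receives each completion message only once.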

Usage

The middleware can be used by wrapping any app, e.g.

app = DomainDispatcherMiddleware({
    "example.com": example_app,
    "another.com": another_app,
})
app = BasicMiddleware(app)
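
Before placing the wrapped app behind a server, the routing can be checked by calling it directly with a hand-built scope. The following is a rough sketch under stated assumptions: HostDispatcher is a simplified stand-in that uses the same host-header lookup as the middleware (http only, no lifespan handling), and make_text_app and request are hypothetical helpers.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List, Tuple

Message = Dict[str, Any]
App = Callable[..., Awaitable[None]]


def make_text_app(body: bytes) -> App:
    """Hypothetical helper: build an ASGI app that always answers with body."""

    async def app(scope, receive, send) -> None:
        await send(
            {
                "type": "http.response.start",
                "status": 200,
                "headers": [(b"content-length", str(len(body)).encode())],
            }
        )
        await send({"type": "http.response.body", "body": body})

    return app


class HostDispatcher:
    """Stand-in with the same host-header lookup as the middleware."""

    def __init__(self, mounts: Dict[str, App]) -> None:
        self.mounts = mounts

    async def __call__(self, scope, receive, send) -> None:
        host = next(
            (value.decode() for name, value in scope["headers"] if name == b"host"),
            None,
        )
        app = self.mounts.get(host)
        if app is not None:
            await app(scope, receive, send)
        else:
            await send(
                {
                    "type": "http.response.start",
                    "status": 404,
                    "headers": [(b"content-length", b"0")],
                }
            )
            await send({"type": "http.response.body"})


async def request(app: App, host: str) -> Tuple[int, bytes]:
    """Send one fake GET for the given host; return the status and body."""
    sent: List[Message] = []

    async def receive() -> Message:
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message: Message) -> None:
        sent.append(message)

    scope = {"type": "http", "headers": [(b"host", host.encode())]}
    await app(scope, receive, send)
    body = b"".join(message.get("body", b"") for message in sent[1:])
    return sent[0]["status"], body


app = HostDispatcher(
    {
        "example.com": make_text_app(b"example"),
        "another.com": make_text_app(b"another"),
    }
)
```

For instance, asyncio.run(request(app, "example.com")) reaches the first app, while a host that is not mounted falls through to the 404 branch.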

Note: if you are using Quart apps, consider using the app.asgi_app attribute, as described in the docs, as this will make testing easier.

Complete solution

Putting it all together,

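import asyncio
from functools import partial
from typing import Callable, Dict, Iterable, Optional, Tuple

# The types come from Hypercorn's typing module, as referenced earlier.
from hypercorn.typing import ASGIFramework, ASGIReceiveCallable, ASGISendCallable, Scope

MAX_QUEUE_SIZE = 10  # Assumed value; the snippets above leave it undefined.


class DomainDispatcherMiddleware:
    def __init__(self, mounts: Dict[str, ASGIFramework]) -> None:
        self.mounts = mounts

    async def __call__(
        self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
    ) -> None:
        if scope["type"] == "lifespan":
            self.app_queues: Dict[str, asyncio.Queue] = {
                domain: asyncio.Queue(MAX_QUEUE_SIZE) for domain in self.mounts
            }
            self.startup_complete = {domain: False for domain in self.mounts}
            self.shutdown_complete = {domain: False for domain in self.mounts}

            # Run every app's lifespan handler as its own task.
            for domain, app in self.mounts.items():
                asyncio.ensure_future(
                    app(
                        scope,
                        self.app_queues[domain].get,
                        partial(self.send, domain, send),
                    )
                )

            # Forward each server message to every app.
            while True:
                message = await receive()
                for queue in self.app_queues.values():
                    await queue.put(message)
                if message["type"] == "lifespan.shutdown":
                    break
        elif scope["type"] in {"http", "websocket"}:
            host = self._get_host(scope["headers"])
            app = self.mounts.get(host)
            if app is not None:
                await app(scope, receive, send)
            else:
                await send(
                    {
                        "type": "http.response.start",
                        "status": 404,
                        "headers": [(b"content-length", b"0")],
                    }
                )
                await send({"type": "http.response.body"})

    def _get_host(self, headers: Iterable[Tuple[bytes, bytes]]) -> Optional[str]:
        for name, value in headers:
            if name.lower() == b"host":
                return value.decode()
        return None

    async def send(self, domain: str, send: Callable, message: dict) -> None:
        # Report completion to the server only once every app has reported.
        if message["type"] == "lifespan.startup.complete":
            self.startup_complete[domain] = True
            if all(self.startup_complete.values()):
                await send({"type": "lifespan.startup.complete"})
        elif message["type"] == "lifespan.shutdown.complete":
            self.shutdown_complete[domain] = True
            if all(self.shutdown_complete.values()):
                await send({"type": "lifespan.shutdown.complete"})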