EXPLAIN asynchronous SQLAlchemy queries

The easiest way to understand the performance characteristics of database operations is to run EXPLAIN on them and read their query execution plans. SQLAlchemy provides a way to log the SQL statements it runs, but those only tell you so much. Let's find a way to automatically print out the execution plan too.
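To make "reading the execution plan" concrete, here is a minimal sketch that needs no server. It swaps in SQLite's EXPLAIN QUERY PLAN from the Python standard library instead of PostgreSQL's EXPLAIN, so the output format differs from what the rest of this post prints. The users table and the query_plan helper are made up for illustration.

```python
import sqlite3


def query_plan(conn: sqlite3.Connection, sql: str, params: tuple = ()) -> list[str]:
    """Return the human-readable detail column of SQLite's EXPLAIN QUERY PLAN."""
    rows = conn.execute(f"EXPLAIN QUERY PLAN {sql}", params).fetchall()
    # The last column of each row holds the textual description of the step.
    return [row[-1] for row in rows]


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT, name TEXT)")

sql = "SELECT * FROM users WHERE email = ?"
plan_without_index = query_plan(conn, sql, ("a@example.com",))

# Same query, but now the planner has an index it can use.
conn.execute("CREATE INDEX idx_users_email ON users (email)")
plan_with_index = query_plan(conn, sql, ("a@example.com",))

print(plan_without_index)  # describes a scan over the whole table
print(plan_with_index)     # describes a search that names idx_users_email
```

The SQL text is identical in both calls, which is exactly why a statement log alone cannot tell you whether a query is cheap or expensive.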

First attempt

The simplest SQLAlchemy setup looks something like this:

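from sqlalchemy import select
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

engine = create_async_engine(url)
sessionmaker = async_sessionmaker(bind=engine)

# async_sessionmaker is a factory: calling it returns a new AsyncSession.
async with sessionmaker() as session:
    await session.execute(select(Foobar))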

Simple and easy: you can use the sessionmaker to create asynchronous sessions and run queries. Now, create_async_engine has an echo parameter we can use to log all the SQL statements those queries actually execute. I want to do better, though.

SQLAlchemy's overly engineered tendencies end up helping us here. async_sessionmaker also takes a parameter, class_, for the session class it actually constructs. Let's override that and capture the queries.

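This approach is NOT production-safe: literal_binds renders parameter values directly into the SQL string, so read the SQLAlchemy documentation for literal_binds before using it. The final solution is OK in that respect, but obviously very slow, since it issues an extra EXPLAIN for every statement.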

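from typing import Any

from sqlalchemy import text
from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker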

class AsyncSessionExplain(AsyncSession):
    async def execute(
        self,
        statement: Any,
        *k: Any,
        **kw: Any,
    ) -> Any:
        sql = statement.compile(
            compile_kwargs={"literal_binds": True}, dialect=postgresql.dialect()
        )
        explain = await super().execute(text(f"EXPLAIN {sql}"))

        print(
            f"\nExplain for {sql}:\n\n" + "\n".join(r[0] for r in explain.all())
        )

        return await super().execute(statement, *k, **kw)

sessionmaker = async_sessionmaker(
    bind=engine,
    class_=AsyncSessionExplain
)

This is simple enough: we take the executed query, compile it to SQL and run EXPLAIN on it. However, there's a massive problem: one SQLAlchemy query can run arbitrarily many SQL statements (for example, a relationship loaded with selectinload issues a second SELECT), and statement.compile is not smart enough to know that.

Second attempt

So, we need to somehow capture all SQL statements run inside the depths of execute. I'll be entirely honest - I tried to figure out where that happens and gave up. The code is too difficult to navigate. However, SQLAlchemy does provide an event we can use to listen for the statements it runs.

from sqlalchemy import event

def before_cursor_execute(conn, cursor, statement, parameters, context, executemany) -> None:
    print(statement, parameters)

engine = create_async_engine(url)
event.listen(
    engine.sync_engine, "before_cursor_execute", before_cursor_execute
)

This is excellent: we get to capture all of the queries. But note that this event handler runs inside an asynchronous engine, which runs its synchronous internals in greenlets. How do we associate a captured statement with the correct asynchronous call?

Context variables to the rescue! Each asynchronous task has its own unique context, which is then also passed down to the greenlet and thus the event.
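The task-level half of that claim can be checked with the standard library alone. The sketch below is an illustration under assumptions: TRACKED, hook and run_query are hypothetical stand-ins, and hook plays the role of the synchronous event handler. It does not involve greenlets, so the hand-off from task to greenlet is taken from the description above rather than demonstrated here.

```python
import asyncio
from contextvars import ContextVar

# Hypothetical stand-in for the tracking variable used in this post.
TRACKED: ContextVar[list[str]] = ContextVar("TRACKED")


def hook(statement: str) -> None:
    """Synchronous callback, playing the role of before_cursor_execute."""
    TRACKED.get().append(statement)


async def run_query(name: str) -> list[str]:
    stmts: list[str] = []
    token = TRACKED.set(stmts)
    try:
        hook(f"{name}: first statement")
        await asyncio.sleep(0)  # yield control so the two tasks interleave
        hook(f"{name}: second statement")
    finally:
        TRACKED.reset(token)
    return stmts


async def main() -> list[list[str]]:
    # gather() wraps each coroutine in its own task, and every task runs
    # in its own copy of the context, so the two lists never mix.
    return list(await asyncio.gather(run_query("a"), run_query("b")))


results = asyncio.run(main())
print(results)
```

Even though the two tasks interleave at the sleep, each list only contains the statements recorded by its own task.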

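import logging
from contextvars import ContextVar

logger = logging.getLogger(__name__)

# Statements captured by the event hook for the call currently being tracked.
TRACKED_STATEMENTS = ContextVar[list[tuple[str, tuple[Any, ...]]]]("TRACKED_STATEMENTS")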

class AsyncSessionExplain(AsyncSession):
    async def execute(
        self,
        statement: Any,
        *k: Any,
        **kw: Any,
    ) -> Any:
        stmts: list[tuple[str, tuple[Any, ...]]] = []
        token = TRACKED_STATEMENTS.set(stmts)
        try:
            # Delegate to the real execute; the event hook fills stmts meanwhile.
            result = await super().execute(statement, *k, **kw)
        finally:
            TRACKED_STATEMENTS.reset(token)

        for stmt, params in stmts:
            print(stmt)

        return result


def before_cursor_execute(conn, cursor, statement, parameters, context, executemany) -> None:
    try:
        stmts = TRACKED_STATEMENTS.get()
        stmts.append((statement, parameters))
    except LookupError:
        logger.info("Running SQL outside a tracked context: \n%s", statement)

This works and we get a full list of the statements executed. However, these statements have variable placeholders (for example, SELECT * FROM user WHERE id = $1). This would not be an issue with PostgreSQL's EXPLAIN (GENERIC_PLAN), available since PostgreSQL 16, and realistically we should be able to just run the following.

explain = await super().execute(text(f"EXPLAIN (GENERIC_PLAN) {stmt}"))

However, SQLAlchemy's executor notices the parameters and notices we didn't pass any. Errors are thrown, everybody is sad. Luckily, we also have access to the parameters. Unluckily, the statement is at this point in a format SQLAlchemy doesn't understand, so we need to do some massaging.

for stmt, params in stmts:
    stmt = re.sub(r"\$(\d+)", r" :\1 ", stmt)
    explain = await super().execute(
        text(f"EXPLAIN {stmt}").bindparams(
            **{str(i + 1): x for i, x in enumerate(params)}
        ),
    )
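The massaging step is easy to try in isolation. This is a minimal sketch of the same rewrite with only the standard library. The to_named_binds helper and the sample statement are hypothetical, and nothing is sent to a database here.

```python
import re
from typing import Any

# Matches the numbered placeholders $1, $2, ... seen in the captured statements.
_PLACEHOLDER = re.compile(r"\$(\d+)")


def to_named_binds(stmt: str, params: tuple[Any, ...]) -> tuple[str, dict[str, Any]]:
    """Rewrite $1, $2, ... into :1, :2, ... and pair the names with their values."""
    rewritten = _PLACEHOLDER.sub(r" :\1 ", stmt)
    # Placeholders are 1-based, the parameter tuple is 0-based.
    binds = {str(i + 1): value for i, value in enumerate(params)}
    return rewritten, binds


sql, binds = to_named_binds(
    "SELECT * FROM users WHERE id = $1 AND name = $2", (7, "alice")
)
print(sql)
print(binds)
```

The padding spaces around each name presumably keep it from running into adjacent text such as a ::type cast. Note that a plain regex cannot tell a real placeholder from a $1 that happens to appear inside a string literal, which is one more reason to keep this to debugging.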

This works, but we still have two issues:

  1. Not all AsyncSession functions funnel through execute, so for example session.get_one is missed here.
  2. The EXPLAIN calls themselves are logged as running outside a tracked context, because they execute after the context variable has been reset.

The solution

Let's put all of that together and abstract the tracking for the other functions.

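import logging
import re
from collections.abc import Callable, Coroutine
from contextvars import ContextVar
from typing import Any, Concatenate

from sqlalchemy import event, select, text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

logger = logging.getLogger(__name__)

# Statements captured by the event hook for the call currently being tracked.
TRACKED_STATEMENTS = ContextVar[list[tuple[str, tuple[Any, ...]]]]("TRACKED_STATEMENTS")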

def track_call[T, **P](
    func: Callable[Concatenate[AsyncSession, P], Coroutine[Any, Any, T]],
) -> Callable[Concatenate[AsyncSession, P], Coroutine[Any, Any, T]]:
    async def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T:
        stmts: list[tuple[str, tuple[Any, ...]]] = []
        token = TRACKED_STATEMENTS.set(stmts)
        try:
            result = await func(self, *args, **kwargs)
        finally:
            TRACKED_STATEMENTS.reset(token)

        token = TRACKED_STATEMENTS.set([])
        try:
            for stmt, params in stmts:
                stmt = re.sub(r"\$(\d+)", r" :\1 ", stmt)
                explain = await AsyncSession.execute(
                    self,
                    text(f"EXPLAIN {stmt}").bindparams(
                        **{str(i + 1): x for i, x in enumerate(params)}
                    ),
                )

                print(
                    f"\nExplain for {stmt}:\n\n"
                    + "\n".join(r[0] for r in explain.all())
                )
        finally:
            TRACKED_STATEMENTS.reset(token)

        return result

    return wrapper


class AsyncSessionExplain(AsyncSession):
    execute = track_call(AsyncSession.execute)
    get_one = track_call(AsyncSession.get_one)
    refresh = track_call(AsyncSession.refresh)


def before_cursor_execute(conn, cursor, statement, parameters, context, executemany) -> None:
    try:
        stmts = TRACKED_STATEMENTS.get()
        stmts.append((statement, parameters))
    except LookupError:
        logger.info("Running SQL outside a tracked context: \n%s", statement)

engine = create_async_engine(url)
event.listen(
    engine.sync_engine, "before_cursor_execute", before_cursor_execute
)

sessionmaker = async_sessionmaker(
    bind=engine,
    class_=AsyncSessionExplain
)

async with sessionmaker() as session:
    await session.execute(select(Foobar))