EXPLAIN asynchronous SQLAlchemy queries
The easiest way to understand the performance characteristics of database
operations is to run EXPLAIN on them and read their query execution plan.
SQLAlchemy provides a way to log the SQL statements it runs, but those only
tell you so much. Let's find a way to automatically print out the execution
plan too.
First attempt
The simplest SQLAlchemy setup looks something like this:
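```python
from sqlalchemy import select
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

engine = create_async_engine(url)  # url and the Foobar model are your own
sessionmaker = async_sessionmaker(bind=engine)

async with sessionmaker() as session:  # inside a coroutine
    await session.execute(select(Foobar))
```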
Simple and easy: you call the sessionmaker to create asynchronous sessions
and run queries. Now, create_async_engine has an echo parameter we can use to
log all SQL statements those queries actually execute. I want to do better,
though.
SQLAlchemy's overly engineered tendencies end up helping us here.
async_sessionmaker also takes a class_ parameter for the session class it
constructs when called. Let's override that class and capture the queries.
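This approach is NOT production-safe: it renders parameter values inline with
literal_binds, which is meant for debugging (see the SQLAlchemy documentation
for literal_binds). The final solution is OK, but obviously very slow, because
every statement is followed by an extra EXPLAIN round trip.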
```python
from typing import Any

from sqlalchemy import text
from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker


class AsyncSessionExplain(AsyncSession):
    async def execute(
        self,
        statement: Any,
        *k: Any,
        **kw: Any,
    ) -> Any:
        # Render the statement with its parameter values inlined (debugging only).
        sql = statement.compile(
            compile_kwargs={"literal_binds": True}, dialect=postgresql.dialect()
        )
        explain = await super().execute(text(f"EXPLAIN {sql}"))
        print(
            f"\nExplain for {sql}:\n\n" + "\n".join(r[0] for r in explain.all())
        )
        # Then run the real query as usual.
        return await super().execute(statement, *k, **kw)


sessionmaker = async_sessionmaker(
    bind=engine,
    class_=AsyncSessionExplain,
)
```
This is simple enough: we take the executed query, compile it to SQL, and run
EXPLAIN on it. However, there's a massive problem. One SQLAlchemy call can
run arbitrarily many SQL statements (for example, an autoflush of pending
changes or a selectinload() eager load), and statement.compile only knows
about the one statement it was given.
Second attempt
So, we need to somehow capture all SQL statements run inside the depths of
execute. I'll be entirely honest: I tried to figure out where that happens
and gave up, because the code is too difficult to navigate. However, SQLAlchemy
does provide an event, before_cursor_execute, that we can use to listen for the
statements it runs.
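```python
from sqlalchemy import event
from sqlalchemy.ext.asyncio import create_async_engine


def before_cursor_execute(
    conn, cursor, statement, parameters, context, executemany
) -> None:
    print(statement, parameters)


engine = create_async_engine(url)
# Engine events are registered on the sync engine wrapped by the async engine.
event.listen(
    engine.sync_engine, "before_cursor_execute", before_cursor_execute
)
```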
This is excellent: we get to capture all of the queries. But note that this event handler runs synchronously, inside the greenlets the asynchronous engine uses to drive the underlying sync engine. How do we associate a captured statement with the asynchronous call that caused it?
Context variables to the rescue! Each asynchronous task has its own context, which SQLAlchemy also passes down to the greenlet and thus to the event handler.
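Before wiring this into SQLAlchemy, here is a minimal standard-library sketch of the isolation we rely on. It involves no greenlets. The names TRACKED, record and tracked_query are hypothetical, and record stands in for the before_cursor_execute hook. Two concurrent tasks each install their own list and then interleave their calls. Because asyncio runs each task in a copy of the context, each list ends up holding only its own task's statements.

```python
import asyncio
from contextvars import ContextVar

# Hypothetical stand-in for TRACKED_STATEMENTS: each task installs its own list.
TRACKED: ContextVar[list[str]] = ContextVar("TRACKED")


def record(statement: str) -> None:
    """Stand-in for the event hook: append to the list in the current context."""
    try:
        TRACKED.get().append(statement)
    except LookupError:
        # No list installed: we are outside any tracked call.
        print("untracked:", statement)


async def tracked_query(name: str, n: int) -> list[str]:
    stmts: list[str] = []
    token = TRACKED.set(stmts)
    try:
        for i in range(n):
            record(f"{name} statement {i}")
            # Yield to the event loop so the two tasks interleave.
            await asyncio.sleep(0)
    finally:
        TRACKED.reset(token)
    return stmts


async def main() -> tuple[list[str], list[str]]:
    # gather() wraps each coroutine in a Task, and each Task gets its own context copy.
    a, b = await asyncio.gather(tracked_query("a", 3), tracked_query("b", 2))
    return a, b


a_stmts, b_stmts = asyncio.run(main())
print(a_stmts)  # ['a statement 0', 'a statement 1', 'a statement 2']
print(b_stmts)  # ['b statement 0', 'b statement 1']
```

The same pattern, applied to the session and the event handler, looks like this: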
```python
import logging
from contextvars import ContextVar
from typing import Any

from sqlalchemy.ext.asyncio import AsyncSession

logger = logging.getLogger(__name__)

# Statements (and their parameters) captured for the current task.
TRACKED_STATEMENTS: ContextVar[list[tuple[str, tuple[Any, ...]]]] = ContextVar(
    "TRACKED_STATEMENTS"
)


class AsyncSessionExplain(AsyncSession):
    async def execute(
        self,
        statement: Any,
        *k: Any,
        **kw: Any,
    ) -> Any:
        stmts: list[tuple[str, tuple[Any, ...]]] = []
        token = TRACKED_STATEMENTS.set(stmts)
        try:
            result = await super().execute(statement, *k, **kw)
        finally:
            TRACKED_STATEMENTS.reset(token)
        for stmt, params in stmts:
            print(stmt)
        return result


# Registered with event.listen(engine.sync_engine, ...) as before.
def before_cursor_execute(
    conn, cursor, statement, parameters, context, executemany
) -> None:
    try:
        stmts = TRACKED_STATEMENTS.get()
        stmts.append((statement, parameters))
    except LookupError:
        logger.info("Running SQL outside a tracked context: \n%s", statement)
```
This works, and we get a full list of the statements executed. However, these
statements contain positional placeholders (for example,
SELECT * FROM user WHERE id = $1). This would not be an issue with
PostgreSQL's EXPLAIN (GENERIC_PLAN), available since PostgreSQL 16, and
realistically we should be able to just run the following:
```python
explain = await super().execute(text(f"EXPLAIN (GENERIC_PLAN) {stmt}"))
```
However, SQLAlchemy's executor notices the placeholders and notices that we didn't pass any parameters. Errors are thrown, and everybody is sad. Luckily, we also have access to the parameters. Unluckily, the statement is now in a format SQLAlchemy's text() doesn't understand ($1-style positional placeholders instead of :name placeholders), so we need to do some massaging.
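The massaging is plain string work, so it can be sketched with only the standard library. The helper name to_named_binds is hypothetical. It rewrites each $N to :N, padded with spaces so a following ::TYPE cast stays separate from the bind name, and it maps the positional parameters to the names "1", "2", and so on. It is a regex, not a SQL parser, so a $N inside a string literal would be rewritten too.

```python
import re
from collections.abc import Sequence
from typing import Any

# Positional placeholders as they appear in the captured SQL: $1, $2, ..., $10, ...
_PLACEHOLDER = re.compile(r"\$(\d+)")


def to_named_binds(
    statement: str, params: Sequence[Any]
) -> tuple[str, dict[str, Any]]:
    """Hypothetical helper: turn $N into :N and number the parameters to match."""
    # Padding with spaces keeps ":N" from fusing with e.g. a following "::INTEGER".
    rewritten = _PLACEHOLDER.sub(r" :\1 ", statement)
    # $1 binds the first parameter, $2 the second, and so on.
    named = {str(i + 1): value for i, value in enumerate(params)}
    return rewritten, named


sql, binds = to_named_binds(
    "SELECT * FROM users WHERE id = $1::INTEGER AND name = $2", (42, "alice")
)
print(repr(sql))
print(binds)  # {'1': 42, '2': 'alice'}
```

Inside the session, the same rewrite and binding look like this: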
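```python
import re

# Inside AsyncSessionExplain.execute, after TRACKED_STATEMENTS has been reset:
for stmt, params in stmts:
    # Turn $1-style placeholders into text()-style :1 placeholders.
    stmt = re.sub(r"\$(\d+)", r" :\1 ", stmt)
    explain = await super().execute(
        text(f"EXPLAIN {stmt}").bindparams(
            # Bind the captured positional parameters under their new names.
            **{str(i + 1): x for i, x in enumerate(params)}
        ),
    )
```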
This works, but we still have two issues:
- Not all AsyncSession methods funnel through execute, so, for example,
  session.get_one is missed here.
- The EXPLAIN calls themselves get logged as running outside a tracked
  context, because TRACKED_STATEMENTS has already been reset when they run.
The solution
Let's put all of that together and abstract the tracking so it also covers the other session methods. (The track_call[T, **P] type-parameter syntax requires Python 3.12 or newer.)
```python
import logging
import re
from collections.abc import Callable, Coroutine
from contextvars import ContextVar
from typing import Any, Concatenate

from sqlalchemy import event, select, text
from sqlalchemy.ext.asyncio import (
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)

logger = logging.getLogger(__name__)

TRACKED_STATEMENTS: ContextVar[list[tuple[str, tuple[Any, ...]]]] = ContextVar(
    "TRACKED_STATEMENTS"
)


def track_call[T, **P](
    func: Callable[Concatenate[AsyncSession, P], Coroutine[Any, Any, T]],
) -> Callable[Concatenate[AsyncSession, P], Coroutine[Any, Any, T]]:
    async def wrapper(self: AsyncSession, *args: P.args, **kwargs: P.kwargs) -> T:
        # Collect every statement the wrapped call runs.
        stmts: list[tuple[str, tuple[Any, ...]]] = []
        token = TRACKED_STATEMENTS.set(stmts)
        try:
            result = await func(self, *args, **kwargs)
        finally:
            TRACKED_STATEMENTS.reset(token)
        # Capture the EXPLAIN queries into a throwaway list so they are not
        # logged as untracked.
        token = TRACKED_STATEMENTS.set([])
        try:
            for stmt, params in stmts:
                stmt = re.sub(r"\$(\d+)", r" :\1 ", stmt)
                # Call the base method directly so EXPLAIN is not wrapped again.
                explain = await AsyncSession.execute(
                    self,
                    text(f"EXPLAIN {stmt}").bindparams(
                        **{str(i + 1): x for i, x in enumerate(params)}
                    ),
                )
                print(
                    f"\nExplain for {stmt}:\n\n"
                    + "\n".join(r[0] for r in explain.all())
                )
        finally:
            TRACKED_STATEMENTS.reset(token)
        return result

    return wrapper


class AsyncSessionExplain(AsyncSession):
    execute = track_call(AsyncSession.execute)
    get_one = track_call(AsyncSession.get_one)
    refresh = track_call(AsyncSession.refresh)


def before_cursor_execute(
    conn, cursor, statement, parameters, context, executemany
) -> None:
    try:
        stmts = TRACKED_STATEMENTS.get()
        stmts.append((statement, parameters))
    except LookupError:
        logger.info("Running SQL outside a tracked context: \n%s", statement)


engine = create_async_engine(url)  # url and the Foobar model are your own
event.listen(
    engine.sync_engine, "before_cursor_execute", before_cursor_execute
)
sessionmaker = async_sessionmaker(
    bind=engine,
    class_=AsyncSessionExplain,
)


async def main() -> None:
    async with sessionmaker() as session:
        await session.execute(select(Foobar))
```
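To check the tracking logic without a database, here is a standard-library-only sketch. FakeSession, ExplainSession, hook, the simplified track_call, and the explained and untracked lists are all hypothetical stand-ins for AsyncSession, before_cursor_execute and the print calls. Only the ContextVar handling mirrors the code above. Like AsyncSession.get_one, FakeSession.get_one does not go through execute, yet its statement is still captured, and the EXPLAIN calls never reach the untracked branch.

```python
import asyncio
from collections.abc import Awaitable, Callable
from contextvars import ContextVar
from typing import Any

TRACKED: ContextVar[list[str]] = ContextVar("TRACKED")
untracked: list[str] = []  # statements seen outside any tracked call
explained: list[str] = []  # statements the wrapper "explained"


def hook(statement: str) -> None:
    """Hypothetical stand-in for before_cursor_execute."""
    try:
        TRACKED.get().append(statement)
    except LookupError:
        untracked.append(statement)


class FakeSession:
    """Hypothetical session whose methods each issue a 'statement'."""

    async def execute(self, statement: str) -> str:
        hook(statement)
        return statement

    async def get_one(self, ident: int) -> int:
        # Does not call execute(), mirroring AsyncSession.get_one.
        hook(f"SELECT ... WHERE id = {ident}")
        return ident


def track_call(func: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
    async def wrapper(self: FakeSession, *args: Any, **kwargs: Any) -> Any:
        stmts: list[str] = []
        token = TRACKED.set(stmts)
        try:
            result = await func(self, *args, **kwargs)
        finally:
            TRACKED.reset(token)
        # Throwaway list: the EXPLAIN calls are tracked, just not reported.
        token = TRACKED.set([])
        try:
            for stmt in stmts:
                # Unwrapped base method, so this does not recurse into wrapper.
                await FakeSession.execute(self, f"EXPLAIN {stmt}")
                explained.append(stmt)
        finally:
            TRACKED.reset(token)
        return result

    return wrapper


class ExplainSession(FakeSession):
    execute = track_call(FakeSession.execute)
    get_one = track_call(FakeSession.get_one)


async def main() -> None:
    session = ExplainSession()
    await session.execute("SELECT 1")
    await session.get_one(7)


asyncio.run(main())
print(explained)  # ['SELECT 1', 'SELECT ... WHERE id = 7']
print(untracked)  # []
```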