from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Generic, Protocol, cast, overload

from sqlalchemy import Engine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from starlette.concurrency import run_in_threadpool
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request  # noqa: TC002

from advanced_alchemy.config.common import EngineT, SessionT
from advanced_alchemy.exceptions import ImproperConfigurationError

if TYPE_CHECKING:
    from sqlalchemy.orm import Session
    from starlette.applications import Starlette
    from starlette.responses import Response

    from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig
    from advanced_alchemy.config.sync import SQLAlchemySyncConfig
    from advanced_alchemy.config.types import CommitStrategy

__all__ = ("CommitStrategyExecutor", "StarletteAdvancedAlchemy")

class CommitStrategyExecutor(Protocol):
    """Structural type for the coroutines that finalize a request's transaction.

    Any async callable accepting the keyword-only arguments below satisfies
    this protocol.

    Args:
        session (sqlalchemy.orm.Session | sqlalchemy.ext.asyncio.AsyncSession):
            The database session.
        response (starlette.responses.Response): The HTTP response.

    Returns:
        None
    """

    async def __call__(
        self,
        *,
        session: Session | AsyncSession,
        response: Response,
    ) -> None:
        ...
class StarletteAdvancedAlchemy(Generic[EngineT, SessionT]):
    """AdvancedAlchemy integration for Starlette applications.

    This class manages SQLAlchemy sessions and engine lifecycle within a
    Starlette application. It provides middleware for handling transactions
    based on commit strategies.

    Args:
        config (advanced_alchemy.config.asyncio.SQLAlchemyAsyncConfig | advanced_alchemy.config.sync.SQLAlchemySyncConfig):
            The SQLAlchemy configuration.
        autocommit (advanced_alchemy.config.types.CommitStrategy | None):
            The commit strategy to use. Defaults to None.
        app (starlette.applications.Starlette | None):
            The Starlette application instance. Defaults to None.
    """

    @overload
    def __init__(
        self: StarletteAdvancedAlchemy[AsyncEngine, AsyncSession],
        config: SQLAlchemyAsyncConfig,
        autocommit: CommitStrategy | None = None,
        app: Starlette | None = None,
    ) -> None: ...

    @overload
    def __init__(
        self: StarletteAdvancedAlchemy[Engine, Session],
        config: SQLAlchemySyncConfig,
        autocommit: CommitStrategy | None = None,
        app: Starlette | None = None,
    ) -> None: ...

    def __init__(
        self,
        config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig,
        autocommit: CommitStrategy | None = None,
        app: Starlette | None = None,
    ) -> None:
        """Stores the configuration and optionally binds to an application.

        Args:
            config: The SQLAlchemy configuration (async or sync).
            autocommit: The commit strategy to apply after each request, or
                None to never commit automatically.
            app: When given, :meth:`init_app` is called with it immediately.
        """
        self.config = config
        # Declared here but only assigned by ``init_app``; reading them before
        # that raises ``AttributeError`` (the ``app`` property relies on this).
        self._app: Starlette
        self.engine_key: str
        self.sessionmaker_key: str
        self.session_key: str
        self.autocommit_strategy = autocommit
        self._commit_strategies: dict[CommitStrategy, CommitStrategyExecutor] = {
            "always": self._commit_strategy_always,
            "match_status": self._commit_strategy_match_status,
        }
        if app is not None:
            self.init_app(app)

    @staticmethod
    def _make_unique_state_key(app: Starlette, key: str) -> str:
        """Generates a unique state key for the Starlette application.

        Returns ``key`` unchanged when ``app.state`` has no attribute of that
        name; otherwise tries ``key_0``, ``key_1``, ... until a free name is
        found.

        Args:
            app (starlette.applications.Starlette): The Starlette application instance.
            key (str): The base key name.

        Returns:
            str: A key name that is not currently set on ``app.state``.
        """
        candidate = key
        i = 0
        while hasattr(app.state, candidate):
            # Always derive the candidate from the base key so suffixes do not
            # pile up (``key_0_0_0``), and advance the counter by one each time.
            candidate = f"{key}_{i}"
            i += 1
        return candidate

    def init_app(self, app: Starlette) -> None:
        """Initializes the Starlette application with SQLAlchemy engine and sessionmaker.

        Stores the engine and a session maker on ``app.state`` under unique
        keys, registers :meth:`middleware_dispatch` as HTTP middleware and
        :meth:`on_shutdown` as a shutdown handler.

        Args:
            app (starlette.applications.Starlette): The Starlette application instance.
        """
        engine = self.config.get_engine()
        # NOTE(review): the placeholder inside these f-strings was lost from the
        # source this was recovered from; ``engine.name`` is the reconstruction.
        self.engine_key = self._make_unique_state_key(app, f"sqla_engine_{engine.name}")
        self.sessionmaker_key = self._make_unique_state_key(app, f"sqla_sessionmaker_{engine.name}")
        self.session_key = f"sqla_session_{self.sessionmaker_key}"

        setattr(app.state, self.engine_key, engine)
        setattr(app.state, self.sessionmaker_key, self.config.create_session_maker())

        app.add_middleware(BaseHTTPMiddleware, dispatch=self.middleware_dispatch)
        app.add_event_handler("shutdown", self.on_shutdown)  # pyright: ignore[reportUnknownMemberType]

        self._app = app

    async def _do_commit(self, session: Session | AsyncSession) -> None:
        """Commits the current transaction.

        Synchronous sessions are committed in a worker thread via
        ``run_in_threadpool``; async sessions are awaited directly.

        Args:
            session (sqlalchemy.orm.Session | sqlalchemy.ext.asyncio.AsyncSession):
                The database session to commit.

        Returns:
            None
        """
        if not isinstance(session, AsyncSession):
            await run_in_threadpool(session.commit)
        else:
            await session.commit()

    async def _do_rollback(self, session: Session | AsyncSession) -> None:
        """Rolls back the current transaction.

        Synchronous sessions are rolled back in a worker thread via
        ``run_in_threadpool``; async sessions are awaited directly.

        Args:
            session (sqlalchemy.orm.Session | sqlalchemy.ext.asyncio.AsyncSession):
                The database session to rollback.

        Returns:
            None
        """
        if not isinstance(session, AsyncSession):
            await run_in_threadpool(session.rollback)
        else:
            await session.rollback()

    async def _do_close(self, session: Session | AsyncSession) -> None:
        """Closes the database session.

        Synchronous sessions are closed in a worker thread via
        ``run_in_threadpool``; async sessions are awaited directly.

        Args:
            session (sqlalchemy.orm.Session | sqlalchemy.ext.asyncio.AsyncSession):
                The database session to close.

        Returns:
            None
        """
        if not isinstance(session, AsyncSession):
            await run_in_threadpool(session.close)
        else:
            await session.close()

    async def _commit_strategy_always(self, *, session: Session | AsyncSession, response: Response) -> None:
        """Commits the transaction regardless of the response status.

        Args:
            session (sqlalchemy.orm.Session | sqlalchemy.ext.asyncio.AsyncSession):
                The database session to commit.
            response (starlette.responses.Response): The HTTP response. Unused;
                accepted so the method matches ``CommitStrategyExecutor``.

        Returns:
            None
        """
        await self._do_commit(session)

    async def _commit_strategy_match_status(self, *, session: Session | AsyncSession, response: Response) -> None:
        """Commits or rolls back based on the response status code.

        Commits if the status code is between 200 and 299, otherwise rolls back.

        Args:
            session (sqlalchemy.orm.Session | sqlalchemy.ext.asyncio.AsyncSession):
                The database session.
            response (starlette.responses.Response): The HTTP response.

        Returns:
            None
        """
        if 200 <= response.status_code < 300:  # noqa: PLR2004
            await self._do_commit(session)
        else:
            await self._do_rollback(session)

    async def session_handler(self, session: Session | AsyncSession, request: Request, response: Response) -> None:
        """Handles the session after a request is processed.

        Applies the configured commit strategy (if any), then always closes
        the session and removes it from ``request.state``.

        Args:
            session (sqlalchemy.orm.Session | sqlalchemy.ext.asyncio.AsyncSession):
                The database session.
            request (starlette.requests.Request): The incoming HTTP request.
            response (starlette.responses.Response): The outgoing HTTP response.

        Returns:
            None
        """
        try:
            if self.autocommit_strategy:
                await self._commit_strategies[self.autocommit_strategy](session=session, response=response)  # pyright: ignore[reportArgumentType]
        finally:
            await self._do_close(session)
            delattr(request.state, self.session_key)

    @property
    def app(self) -> Starlette:
        """Returns the Starlette application instance.

        Raises:
            advanced_alchemy.exceptions.ImproperConfigurationError:
                If the application is not initialized.

        Returns:
            starlette.applications.Starlette: The Starlette application instance.
        """
        try:
            return self._app
        except AttributeError as e:
            msg = "Application not initialized. Did you forget to call init_app?"
            raise ImproperConfigurationError(msg) from e

    def get_engine(self) -> EngineT:
        """Retrieves the SQLAlchemy engine from the application state.

        Raises:
            advanced_alchemy.exceptions.ImproperConfigurationError:
                If :meth:`init_app` has not been called.

        Returns:
            EngineT: The SQLAlchemy engine instance.
        """
        return cast(EngineT, getattr(self.app.state, self.engine_key))

    def get_sessionmaker(self) -> Callable[[], SessionT]:
        """Retrieves the sessionmaker from the application state.

        Raises:
            advanced_alchemy.exceptions.ImproperConfigurationError:
                If :meth:`init_app` has not been called.

        Returns:
            Callable[[], SessionT]: A callable that returns a new session instance.
        """
        return cast(Callable[[], SessionT], getattr(self.app.state, self.sessionmaker_key))

    def get_session(self, request: Request) -> SessionT:
        """Retrieves or creates a database session for the current request.

        The session is cached on ``request.state`` so repeated calls within
        one request return the same object.

        Args:
            request (starlette.requests.Request): The incoming HTTP request.

        Returns:
            SessionT: The database session.
        """
        session = getattr(request.state, self.session_key, None)
        if session is not None:
            return cast(SessionT, session)

        session = self.get_sessionmaker()()
        setattr(request.state, self.session_key, session)
        return session

    async def middleware_dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """Middleware dispatch function to handle requests and responses.

        Invokes the next middleware or route handler, then hands any session
        that was created for the request to :meth:`session_handler`. If the
        downstream call raises, the session is closed without applying a
        commit strategy and the exception is re-raised.

        Args:
            request (starlette.requests.Request): The incoming HTTP request.
            call_next (starlette.middleware.base.RequestResponseEndpoint):
                The next middleware or route handler.

        Returns:
            starlette.responses.Response: The HTTP response.
        """
        try:
            response = await call_next(request)
        except Exception:
            # No response exists to drive a commit strategy; release the
            # session so it is not left open on the request state.
            failed_session: Session | AsyncSession | None = getattr(request.state, self.session_key, None)
            if failed_session is not None:
                try:
                    await self._do_close(failed_session)
                finally:
                    delattr(request.state, self.session_key)
            raise

        session: Session | AsyncSession | None = getattr(request.state, self.session_key, None)
        if session is not None:
            await self.session_handler(session=session, request=request, response=response)
        return response

    async def on_shutdown(self) -> None:
        """Handles the shutdown event by disposing of the SQLAlchemy engine.

        Disposes the engine (in a worker thread for a synchronous ``Engine``)
        and removes the engine and sessionmaker from the application state.

        Returns:
            None
        """
        engine = getattr(self.app.state, self.engine_key)
        if isinstance(engine, Engine):
            await run_in_threadpool(engine.dispose)
        else:
            await engine.dispose()
        delattr(self.app.state, self.engine_key)
        delattr(self.app.state, self.sessionmaker_key)