Source code for advanced_alchemy.extensions.starlette.extension

# ruff: noqa: ARG001
from __future__ import annotations

import contextlib
from contextlib import asynccontextmanager, contextmanager
from typing import TYPE_CHECKING, AsyncGenerator, Callable, Generator, Sequence, Union, overload

from starlette.applications import Starlette
from starlette.requests import Request  # noqa: TC002

from advanced_alchemy.exceptions import ImproperConfigurationError
from advanced_alchemy.extensions.starlette.config import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig

if TYPE_CHECKING:
    from sqlalchemy import Engine
    from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
    from sqlalchemy.orm import Session
    from starlette.applications import Starlette


[docs] class AdvancedAlchemy: """AdvancedAlchemy integration for Starlette applications. This class manages SQLAlchemy sessions and engine lifecycle within a Starlette application. It provides middleware for handling transactions based on commit strategies. Args: config (advanced_alchemy.config.asyncio.SQLAlchemyAsyncConfig | advanced_alchemy.config.sync.SQLAlchemySyncConfig): The SQLAlchemy configuration. app (starlette.applications.Starlette | None): The Starlette application instance. Defaults to None. """
[docs] def __init__( self, config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig | Sequence[SQLAlchemyAsyncConfig | SQLAlchemySyncConfig], app: Starlette | None = None, ) -> None: self._config = config if isinstance(config, Sequence) else [config] self._mapped_configs: dict[str, Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]] = self.map_configs() # noqa: UP007 self._app: Starlette | None = None if app is not None: self.init_app(app)
@property def config(self) -> Sequence[SQLAlchemyAsyncConfig | SQLAlchemySyncConfig]: """Current Advanced Alchemy configuration.""" return self._config
[docs] def init_app(self, app: Starlette) -> None: """Initializes the Starlette application with SQLAlchemy engine and sessionmaker. Sets up middleware and shutdown handlers for managing the database engine. Args: app (starlette.applications.Starlette): The Starlette application instance. """ self._app = app unique_bind_keys = {config.bind_key for config in self.config} if len(unique_bind_keys) != len(self.config): # pragma: no cover msg = "Please ensure that each config has a unique name for the `bind_key` attribute. The default is `default` and can only be bound to a single engine." raise ImproperConfigurationError(msg) for config in self.config: config.init_app(app) app.state.advanced_alchemy = self
@property def app(self) -> Starlette: """Returns the Starlette application instance. Raises: advanced_alchemy.exceptions.ImproperConfigurationError: If the application is not initialized. Returns: starlette.applications.Starlette: The Starlette application instance. """ if self._app is None: msg = "Application not initialized. Did you forget to call init_app?" raise ImproperConfigurationError(msg) return self._app
[docs] async def on_startup(self) -> None: # pragma: no cover """Initializes the database.""" for config in self.config: await config.on_startup()
[docs] async def on_shutdown(self) -> None: # pragma: no cover """Handles the shutdown event by disposing of the SQLAlchemy engine. Ensures that all connections are properly closed during application shutdown. Returns: None """ for config in self.config: await config.on_shutdown() with contextlib.suppress(AttributeError, KeyError): delattr(self.app.state, "advanced_alchemy")
[docs] def map_configs(self) -> dict[str, Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]: # noqa: UP007 """Maps the configs to the session bind keys.""" mapped_configs: dict[str, Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]] = {} # noqa: UP007 for config in self.config: if config.bind_key is None: config.bind_key = "default" mapped_configs[config.bind_key] = config return mapped_configs
[docs] def get_config(self, key: str | None = None) -> Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]: # noqa: UP007 """Get the config for the given key.""" if key is None: key = "default" if key == "default" and len(self.config) == 1: key = self.config[0].bind_key or "default" config = self._mapped_configs.get(key) if config is None: msg = f"Config with key {key} not found" raise ImproperConfigurationError(msg) return config
[docs] def get_async_config(self, key: str | None = None) -> SQLAlchemyAsyncConfig: """Get the async config for the given key.""" config = self.get_config(key) if not isinstance(config, SQLAlchemyAsyncConfig): msg = "Expected an async config, but got a sync config" raise ImproperConfigurationError(msg) return config
[docs] def get_sync_config(self, key: str | None = None) -> SQLAlchemySyncConfig: """Get the sync config for the given key.""" config = self.get_config(key) if not isinstance(config, SQLAlchemySyncConfig): msg = "Expected a sync config, but got an async config" raise ImproperConfigurationError(msg) return config
[docs] @asynccontextmanager async def with_async_session(self, key: str | None = None) -> AsyncGenerator[AsyncSession, None]: """Context manager for getting an async session.""" config = self.get_async_config(key) async with config.get_session() as session: yield session
[docs] @contextmanager def with_sync_session(self, key: str | None = None) -> Generator[Session, None]: """Context manager for getting a sync session.""" config = self.get_sync_config(key) with config.get_session() as session: yield session
@overload @staticmethod def _get_session_from_request(request: Request, config: SQLAlchemyAsyncConfig) -> AsyncSession: ... @overload @staticmethod def _get_session_from_request(request: Request, config: SQLAlchemySyncConfig) -> Session: ... @staticmethod def _get_session_from_request( request: Request, config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig ) -> Session | AsyncSession: # pragma: no cover """Get the session for the given key.""" session = getattr(request.state, config.session_key, None) if session is None: session = config.create_session_maker()() setattr(request.state, config.session_key, session) return session
[docs] def get_session(self, request: Request, key: str | None = None) -> Session | AsyncSession: """Get the session for the given key.""" config = self.get_config(key) return self._get_session_from_request(request, config)
[docs] def get_async_session(self, request: Request, key: str | None = None) -> AsyncSession: """Get the async session for the given key.""" config = self.get_async_config(key) return self._get_session_from_request(request, config)
[docs] def get_sync_session(self, request: Request, key: str | None = None) -> Session: """Get the sync session for the given key.""" config = self.get_sync_config(key) return self._get_session_from_request(request, config)
[docs] def provide_session(self, key: str | None = None) -> Callable[[Request], Session | AsyncSession]: """Get the session for the given key.""" config = self.get_config(key) def _get_session(request: Request) -> Session | AsyncSession: return self._get_session_from_request(request, config) return _get_session
[docs] def provide_async_session(self, key: str | None = None) -> Callable[[Request], AsyncSession]: """Get the async session for the given key.""" config = self.get_async_config(key) def _get_session(request: Request) -> AsyncSession: return self._get_session_from_request(request, config) return _get_session
[docs] def provide_sync_session(self, key: str | None = None) -> Callable[[Request], Session]: """Get the sync session for the given key.""" config = self.get_sync_config(key) def _get_session(request: Request) -> Session: return self._get_session_from_request(request, config) return _get_session
[docs] def get_engine(self, key: str | None = None) -> Engine | AsyncEngine: # pragma: no cover """Get the engine for the given key.""" config = self.get_config(key) return config.get_engine()
[docs] def get_async_engine(self, key: str | None = None) -> AsyncEngine: """Get the async engine for the given key.""" config = self.get_async_config(key) return config.get_engine()
[docs] def get_sync_engine(self, key: str | None = None) -> Engine: """Get the sync engine for the given key.""" config = self.get_sync_config(key) return config.get_engine()
[docs] def provide_engine(self, key: str | None = None) -> Callable[[], Engine | AsyncEngine]: # pragma: no cover """Get the engine for the given key.""" config = self.get_config(key) def _get_engine() -> Engine | AsyncEngine: return config.get_engine() return _get_engine
[docs] def provide_async_engine(self, key: str | None = None) -> Callable[[], AsyncEngine]: # pragma: no cover """Get the async engine for the given key.""" config = self.get_async_config(key) def _get_engine() -> AsyncEngine: return config.get_engine() return _get_engine
[docs] def provide_sync_engine(self, key: str | None = None) -> Callable[[], Engine]: # pragma: no cover """Get the sync engine for the given key.""" config = self.get_sync_config(key) def _get_engine() -> Engine: return config.get_engine() return _get_engine