Source code for advanced_alchemy.extensions.starlette.config

"""Configuration classes for Starlette integration.

This module provides configuration classes for integrating SQLAlchemy with Starlette applications,
including both synchronous and asynchronous database configurations.
"""

from __future__ import annotations

import contextlib
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable

from click import echo
from sqlalchemy.exc import OperationalError
from starlette.concurrency import run_in_threadpool
from starlette.middleware.base import BaseHTTPMiddleware
from typing_extensions import Literal

from advanced_alchemy._serialization import decode_json, encode_json
from advanced_alchemy.base import metadata_registry
from advanced_alchemy.config import EngineConfig as _EngineConfig
from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig as _SQLAlchemyAsyncConfig
from advanced_alchemy.config.sync import SQLAlchemySyncConfig as _SQLAlchemySyncConfig
from advanced_alchemy.service import schema_dump

if TYPE_CHECKING:
    from typing import Any

    from sqlalchemy.ext.asyncio import AsyncSession
    from sqlalchemy.orm import Session
    from starlette.applications import Starlette
    from starlette.middleware.base import RequestResponseEndpoint
    from starlette.requests import Request
    from starlette.responses import Response


def _make_unique_state_key(app: Starlette, key: str) -> str:  # pragma: no cover
    """Generate a state key that is not yet set on the application's state.

    ``key`` is returned unchanged when ``app.state`` has no attribute of that name.
    Otherwise the first free name of the form ``{key}_{n}`` (``n`` = 0, 1, 2, ...)
    is returned.

    Args:
        app (starlette.applications.Starlette): The Starlette application instance.
        key (str): The base key name.

    Returns:
        str: A key name for which ``hasattr(app.state, name)`` is false.
    """
    if not hasattr(app.state, key):
        return key
    suffix = 0
    # Always build the candidate from the original base key and advance the counter by one,
    # so repeated collisions yield key_0, key_1, key_2 ... rather than key_0_0, key_0_0_0 ...
    while hasattr(app.state, f"{key}_{suffix}"):
        suffix += 1
    return f"{key}_{suffix}"


def serializer(value: Any) -> str:
    """Encode a value as a JSON string.

    The value is first passed through ``schema_dump`` and the result is then
    handed to ``encode_json``.

    Args:
        value: Any JSON serializable value.

    Returns:
        str: JSON string representation of the value.
    """
    dumped = schema_dump(value)
    return encode_json(dumped)


@dataclass
class EngineConfig(_EngineConfig):
    """Configuration for SQLAlchemy's Engine.

    This class extends the base EngineConfig with Starlette-specific JSON serialization options.

    For details see: https://docs.sqlalchemy.org/en/20/core/engines.html

    Attributes:
        json_deserializer: Callable for converting JSON strings to Python objects.
            Defaults to ``decode_json``.
        json_serializer: Callable for converting Python objects to JSON strings.
            Defaults to the module-level ``serializer`` function.
    """

    json_deserializer: Callable[[str], Any] = decode_json
    """For dialects that support the :class:`~sqlalchemy.types.JSON` datatype, this is a Python callable that will
    convert a JSON string to a Python object. By default, this uses the built-in serializers."""
    json_serializer: Callable[[Any], str] = serializer
    """For dialects that support the JSON datatype, this is a Python callable that will render a given object as
    JSON. By default, the built-in serializer is used."""
@dataclass
class SQLAlchemyAsyncConfig(_SQLAlchemyAsyncConfig):
    """SQLAlchemy Async config for Starlette.

    ``init_app`` registers a ``BaseHTTPMiddleware`` whose dispatch function finalizes any session found on
    ``request.state`` under ``session_key`` once the response is available, plus startup/shutdown handlers
    that create metadata (when ``create_all`` is set) and dispose of the engine.
    """

    app: Starlette | None = None
    """The Starlette application instance."""
    commit_mode: Literal["manual", "autocommit", "autocommit_include_redirect"] = "manual"
    """The commit mode to use for database sessions."""
    engine_key: str = "db_engine"
    """Key to use for the dependency injection of database engines."""
    session_key: str = "db_session"
    """Key to use for the dependency injection of database sessions."""
    session_maker_key: str = "session_maker_class"
    """Key under which to store the SQLAlchemy :class:`sessionmaker <sqlalchemy.orm.sessionmaker>` in the application
    state instance.
    """
    engine_config: EngineConfig = field(default_factory=EngineConfig)  # pyright: ignore[reportIncompatibleVariableOverride]
    """Configuration for the SQLAlchemy engine.

    The configuration options are documented in the SQLAlchemy documentation.
    """

    async def create_all_metadata(self) -> None:  # pragma: no cover
        """Create all metadata tables in the database.

        The engine is created on demand. An ``OperationalError`` is reported via ``echo`` and not re-raised.
        """
        if self.engine_instance is None:
            self.engine_instance = self.get_engine()
        async with self.engine_instance.begin() as conn:
            try:
                await conn.run_sync(
                    # The "default" bind key maps to the registry entry stored under ``None``.
                    metadata_registry.get(None if self.bind_key == "default" else self.bind_key).create_all
                )
                await conn.commit()
            except OperationalError as exc:
                echo(f" * Could not create target metadata. Reason: {exc}")
            else:
                echo(" * Created target metadata.")

    def init_app(self, app: Starlette) -> None:
        """Initialize the Starlette application with this configuration.

        Stores the application, ensures a session maker exists, derives unique state keys, and registers
        the session middleware together with the startup and shutdown handlers.

        Args:
            app: The Starlette application instance.
        """
        self.app = app
        self.bind_key = self.bind_key or "default"
        _ = self.create_session_maker()
        self.session_key = _make_unique_state_key(app, f"advanced_alchemy_async_session_{self.session_key}")
        self.engine_key = _make_unique_state_key(app, f"advanced_alchemy_async_engine_{self.engine_key}")
        self.session_maker_key = _make_unique_state_key(
            app, f"advanced_alchemy_async_session_maker_{self.session_maker_key}"
        )

        app.add_middleware(BaseHTTPMiddleware, dispatch=self.middleware_dispatch)
        app.add_event_handler("startup", self.on_startup)  # pyright: ignore[reportUnknownMemberType]
        app.add_event_handler("shutdown", self.on_shutdown)  # pyright: ignore[reportUnknownMemberType]

    async def on_startup(self) -> None:
        """Handle the startup event by creating all metadata tables when ``create_all`` is enabled."""
        if self.create_all:
            await self.create_all_metadata()

    def create_session_maker(self) -> Callable[[], AsyncSession]:
        """Get a session maker. If none exists yet, create one.

        When the session configuration carries no ``bind``, the (lazily created) engine is used.

        Returns:
            Callable[[], AsyncSession]: Session factory used by the plugin.
        """
        if self.session_maker:
            return self.session_maker

        session_kws = self.session_config_dict
        if self.engine_instance is None:
            self.engine_instance = self.get_engine()
        if session_kws.get("bind") is None:
            session_kws["bind"] = self.engine_instance
        self.session_maker = self.session_maker_class(**session_kws)
        return self.session_maker

    async def session_handler(
        self, session: AsyncSession, request: Request, response: Response
    ) -> None:  # pragma: no cover
        """Handles the session after a request is processed.

        Commits when ``commit_mode`` is ``"autocommit"`` and the status is 2xx, or when it is
        ``"autocommit_include_redirect"`` and the status is 2xx/3xx; rolls back otherwise. The session is
        always closed and removed from ``request.state``.

        Args:
            session (sqlalchemy.ext.asyncio.AsyncSession): The database session.
            request (starlette.requests.Request): The incoming HTTP request.
            response (starlette.responses.Response): The outgoing HTTP response.

        Returns:
            None
        """
        try:
            if (self.commit_mode == "autocommit" and 200 <= response.status_code < 300) or (  # noqa: PLR2004
                self.commit_mode == "autocommit_include_redirect" and 200 <= response.status_code < 400  # noqa: PLR2004
            ):
                await session.commit()
            else:
                await session.rollback()
        finally:
            await session.close()
            with contextlib.suppress(AttributeError, KeyError):
                delattr(request.state, self.session_key)

    async def _discard_session(self, session: AsyncSession, request: Request) -> None:  # pragma: no cover
        """Roll back and close a session when no response is available to decide on a commit."""
        try:
            await session.rollback()
        finally:
            await session.close()
            with contextlib.suppress(AttributeError, KeyError):
                delattr(request.state, self.session_key)

    async def middleware_dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:  # pragma: no cover
        """Middleware dispatch function to handle requests and responses.

        Processes the request, invokes the next middleware or route handler, and applies the session
        handler after the response is generated. If the downstream call raises, any session stored on the
        request is rolled back and closed before the exception is re-raised.

        Args:
            request (starlette.requests.Request): The incoming HTTP request.
            call_next (starlette.middleware.base.RequestResponseEndpoint): The next middleware or route handler.

        Returns:
            starlette.responses.Response: The HTTP response.
        """
        try:
            response = await call_next(request)
        except Exception:
            # No response exists to drive the commit strategy; release the session instead of leaking it.
            failed_session: AsyncSession | None = getattr(request.state, self.session_key, None)
            if failed_session is not None:
                await self._discard_session(failed_session, request)
            raise
        session: AsyncSession | None = getattr(request.state, self.session_key, None)
        if session is not None:
            await self.session_handler(session=session, request=request, response=response)

        return response

    async def close_engine(self) -> None:  # pragma: no cover
        """Close the engine."""
        if self.engine_instance is not None:
            await self.engine_instance.dispose()

    async def on_shutdown(self) -> None:  # pragma: no cover
        """Handles the shutdown event by disposing of the SQLAlchemy engine.

        Ensures that all connections are properly closed during application shutdown, then removes the
        engine, session maker and session keys from the application state.

        Returns:
            None
        """
        await self.close_engine()
        if self.app is not None:
            # Suppress per key: a missing attribute must not prevent removal of the remaining ones.
            for state_key in (self.engine_key, self.session_maker_key, self.session_key):
                with contextlib.suppress(AttributeError, KeyError):
                    delattr(self.app.state, state_key)
@dataclass
class SQLAlchemySyncConfig(_SQLAlchemySyncConfig):
    """SQLAlchemy Sync config for Starlette.

    ``init_app`` registers a ``BaseHTTPMiddleware`` whose dispatch function finalizes any session found on
    ``request.state`` under ``session_key`` once the response is available, plus startup/shutdown handlers
    that create metadata (when ``create_all`` is set) and dispose of the engine. Blocking session and engine
    calls are executed through ``run_in_threadpool``.
    """

    app: Starlette | None = None
    """The Starlette application instance."""
    commit_mode: Literal["manual", "autocommit", "autocommit_include_redirect"] = "manual"
    """The commit mode to use for database sessions."""
    engine_key: str = "db_engine"
    """Key to use for the dependency injection of database engines."""
    session_key: str = "db_session"
    """Key to use for the dependency injection of database sessions."""
    session_maker_key: str = "session_maker_class"
    """Key under which to store the SQLAlchemy :class:`sessionmaker <sqlalchemy.orm.sessionmaker>` in the application
    state instance.
    """
    engine_config: EngineConfig = field(default_factory=EngineConfig)  # pyright: ignore[reportIncompatibleVariableOverride]
    """Configuration for the SQLAlchemy engine.

    The configuration options are documented in the SQLAlchemy documentation.
    """

    async def create_all_metadata(self) -> None:  # pragma: no cover
        """Create all metadata tables in the database.

        The engine is created on demand. ``create_all`` is run in the threadpool with the open connection.
        An ``OperationalError`` is reported via ``echo`` and not re-raised.
        """
        if self.engine_instance is None:
            self.engine_instance = self.get_engine()
        with self.engine_instance.begin() as conn:
            try:
                await run_in_threadpool(
                    # The "default" bind key maps to the registry entry stored under ``None``.
                    metadata_registry.get(None if self.bind_key == "default" else self.bind_key).create_all, conn
                )
            except OperationalError as exc:
                echo(f" * Could not create target metadata. Reason: {exc}")

    def init_app(self, app: Starlette) -> None:
        """Initialize the Starlette application with this configuration.

        Stores the application, derives unique state keys, ensures a session maker exists, and registers
        the session middleware together with the startup and shutdown handlers.

        Args:
            app: The Starlette application instance.
        """
        self.app = app
        self.bind_key = self.bind_key or "default"
        self.session_key = _make_unique_state_key(app, f"advanced_alchemy_sync_session_{self.session_key}")
        self.engine_key = _make_unique_state_key(app, f"advanced_alchemy_sync_engine_{self.engine_key}")
        self.session_maker_key = _make_unique_state_key(
            app, f"advanced_alchemy_sync_session_maker_{self.session_maker_key}"
        )
        _ = self.create_session_maker()
        app.add_middleware(BaseHTTPMiddleware, dispatch=self.middleware_dispatch)
        app.add_event_handler("startup", self.on_startup)  # pyright: ignore[reportUnknownMemberType]
        app.add_event_handler("shutdown", self.on_shutdown)  # pyright: ignore[reportUnknownMemberType]

    async def on_startup(self) -> None:
        """Handle the startup event by creating all metadata tables when ``create_all`` is enabled."""
        if self.create_all:
            await self.create_all_metadata()

    def create_session_maker(self) -> Callable[[], Session]:
        """Get a session maker. If none exists yet, create one.

        When the session configuration carries no ``bind``, the (lazily created) engine is used.

        Returns:
            Callable[[], Session]: Session factory used by the plugin.
        """
        if self.session_maker:
            return self.session_maker

        session_kws = self.session_config_dict
        if self.engine_instance is None:
            self.engine_instance = self.get_engine()
        if session_kws.get("bind") is None:
            session_kws["bind"] = self.engine_instance
        self.session_maker = self.session_maker_class(**session_kws)
        return self.session_maker

    async def session_handler(self, session: Session, request: Request, response: Response) -> None:  # pragma: no cover
        """Handles the session after a request is processed.

        Commits when ``commit_mode`` is ``"autocommit"`` and the status is 2xx, or when it is
        ``"autocommit_include_redirect"`` and the status is 2xx/3xx; rolls back otherwise. The session is
        always closed and removed from ``request.state``. All session calls run in the threadpool.

        Args:
            session (sqlalchemy.orm.Session): The database session.
            request (starlette.requests.Request): The incoming HTTP request.
            response (starlette.responses.Response): The outgoing HTTP response.

        Returns:
            None
        """
        try:
            if (self.commit_mode == "autocommit" and 200 <= response.status_code < 300) or (  # noqa: PLR2004
                self.commit_mode == "autocommit_include_redirect" and 200 <= response.status_code < 400  # noqa: PLR2004
            ):
                await run_in_threadpool(session.commit)
            else:
                await run_in_threadpool(session.rollback)
        finally:
            await run_in_threadpool(session.close)
            with contextlib.suppress(AttributeError, KeyError):
                delattr(request.state, self.session_key)

    async def _discard_session(self, session: Session, request: Request) -> None:  # pragma: no cover
        """Roll back and close a session when no response is available to decide on a commit."""
        try:
            await run_in_threadpool(session.rollback)
        finally:
            await run_in_threadpool(session.close)
            with contextlib.suppress(AttributeError, KeyError):
                delattr(request.state, self.session_key)

    async def middleware_dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:  # pragma: no cover
        """Middleware dispatch function to handle requests and responses.

        Processes the request, invokes the next middleware or route handler, and applies the session
        handler after the response is generated. If the downstream call raises, any session stored on the
        request is rolled back and closed before the exception is re-raised.

        Args:
            request (starlette.requests.Request): The incoming HTTP request.
            call_next (starlette.middleware.base.RequestResponseEndpoint): The next middleware or route handler.

        Returns:
            starlette.responses.Response: The HTTP response.
        """
        try:
            response = await call_next(request)
        except Exception:
            # No response exists to drive the commit strategy; release the session instead of leaking it.
            failed_session: Session | None = getattr(request.state, self.session_key, None)
            if failed_session is not None:
                await self._discard_session(failed_session, request)
            raise
        session: Session | None = getattr(request.state, self.session_key, None)
        if session is not None:
            await self.session_handler(session=session, request=request, response=response)

        return response

    async def close_engine(self) -> None:  # pragma: no cover
        """Close the engine."""
        if self.engine_instance is not None:
            await run_in_threadpool(self.engine_instance.dispose)

    async def on_shutdown(self) -> None:  # pragma: no cover
        """Handles the shutdown event by disposing of the SQLAlchemy engine.

        Ensures that all connections are properly closed during application shutdown, then removes the
        engine, session and session maker keys from the application state.

        Returns:
            None
        """
        await self.close_engine()
        if self.app is not None:
            # Suppress per key: a missing attribute must not prevent removal of the remaining ones.
            for state_key in (self.engine_key, self.session_key, self.session_maker_key):
                with contextlib.suppress(AttributeError, KeyError):
                    delattr(self.app.state, state_key)