"""Configuration classes for Starlette integration.
This module provides configuration classes for integrating SQLAlchemy with Starlette applications,
including both synchronous and asynchronous database configurations.
"""
from __future__ import annotations
import contextlib
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable
from click import echo
from sqlalchemy.exc import OperationalError
from starlette.concurrency import run_in_threadpool
from starlette.middleware.base import BaseHTTPMiddleware
from typing_extensions import Literal
from advanced_alchemy._serialization import decode_json, encode_json
from advanced_alchemy.base import metadata_registry
from advanced_alchemy.config import EngineConfig as _EngineConfig
from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig as _SQLAlchemyAsyncConfig
from advanced_alchemy.config.sync import SQLAlchemySyncConfig as _SQLAlchemySyncConfig
from advanced_alchemy.service import schema_dump
if TYPE_CHECKING:
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from starlette.applications import Starlette
from starlette.middleware.base import RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response
def _make_unique_state_key(app: Starlette, key: str) -> str: # pragma: no cover
"""Generates a unique state key for the Starlette application.
Ensures that the key does not already exist in the application's state.
Args:
app (starlette.applications.Starlette): The Starlette application instance.
key (str): The base key name.
Returns:
str: A unique key name.
"""
i = 0
while True:
if not hasattr(app.state, key):
return key
key = f"{key}_{i}"
i += i
def serializer(value: Any) -> str:
    """Render a JSON-field value as a JSON string.

    The value is first passed through ``schema_dump`` and the dumped result is then
    encoded with ``encode_json``.

    Args:
        value: Any JSON serializable value.

    Returns:
        str: JSON string representation of the value.
    """
    dumped = schema_dump(value)
    return encode_json(dumped)
@dataclass
class EngineConfig(_EngineConfig):
    """Configuration for SQLAlchemy's Engine.

    This class extends the base EngineConfig with Starlette-specific JSON serialization options.

    For details see: https://docs.sqlalchemy.org/en/20/core/engines.html

    Attributes:
        json_deserializer: Callable for converting JSON strings to Python objects.
        json_serializer: Callable for converting Python objects to JSON strings.
    """

    json_deserializer: Callable[[str], Any] = decode_json
    """For dialects that support the :class:`~sqlalchemy.types.JSON` datatype, this is a Python callable that will
    convert a JSON string to a Python object. By default, ``decode_json`` is used."""
    json_serializer: Callable[[Any], str] = serializer
    """For dialects that support the JSON datatype, this is a Python callable that will render a given object as JSON.
    By default, the module-level ``serializer`` (``schema_dump`` followed by ``encode_json``) is used."""
@dataclass
class SQLAlchemyAsyncConfig(_SQLAlchemyAsyncConfig):
    """SQLAlchemy Async config for Starlette.

    Call :meth:`init_app` to register the session middleware and the startup/shutdown
    handlers on a Starlette application.
    """

    app: Starlette | None = None
    """The Starlette application instance."""
    commit_mode: Literal["manual", "autocommit", "autocommit_include_redirect"] = "manual"
    """The commit mode to use for database sessions.

    ``"manual"`` never commits in the middleware (anything left uncommitted is rolled back),
    ``"autocommit"`` commits on 2xx responses and ``"autocommit_include_redirect"`` commits
    on 2xx and 3xx responses.
    """
    engine_key: str = "db_engine"
    """Key to use for the dependency injection of database engines."""
    session_key: str = "db_session"
    """Key to use for the dependency injection of database sessions."""
    session_maker_key: str = "session_maker_class"
    """Key under which to store the SQLAlchemy :class:`sessionmaker <sqlalchemy.orm.sessionmaker>` in the application state instance.
    """
    engine_config: EngineConfig = field(default_factory=EngineConfig)  # pyright: ignore[reportIncompatibleVariableOverride]
    """Configuration for the SQLAlchemy engine.

    The configuration options are documented in the SQLAlchemy documentation.
    """

    def init_app(self, app: Starlette) -> None:
        """Initialize the Starlette application with this configuration.

        Creates the session maker (and engine, if needed), prefixes and de-duplicates the
        state key names, and registers the session middleware plus startup/shutdown handlers.

        Args:
            app: The Starlette application instance.
        """
        self.app = app
        self.bind_key = self.bind_key or "default"
        _ = self.create_session_maker()
        # NOTE(review): _make_unique_state_key only avoids names already present on app.state.
        # Nothing in this class sets these attributes, so two configs initialised with the same
        # keys may resolve to the same names unless the caller stores them — confirm.
        self.session_key = _make_unique_state_key(app, f"advanced_alchemy_async_session_{self.session_key}")
        self.engine_key = _make_unique_state_key(app, f"advanced_alchemy_async_engine_{self.engine_key}")
        self.session_maker_key = _make_unique_state_key(
            app, f"advanced_alchemy_async_session_maker_{self.session_maker_key}"
        )
        app.add_middleware(BaseHTTPMiddleware, dispatch=self.middleware_dispatch)
        app.add_event_handler("startup", self.on_startup)  # pyright: ignore[reportUnknownMemberType]
        app.add_event_handler("shutdown", self.on_shutdown)  # pyright: ignore[reportUnknownMemberType]

    async def on_startup(self) -> None:
        """Create all metadata tables on application startup when ``create_all`` is enabled."""
        if self.create_all:
            await self.create_all_metadata()

    def create_session_maker(self) -> Callable[[], AsyncSession]:
        """Get a session maker. If none exists yet, create one.

        The engine is created on demand and used as the session ``bind`` unless the
        session configuration already provides one.

        Returns:
            Callable[[], AsyncSession]: Session factory used by the plugin.
        """
        if self.session_maker:
            return self.session_maker
        session_kws = self.session_config_dict
        if self.engine_instance is None:
            self.engine_instance = self.get_engine()
        if session_kws.get("bind") is None:
            session_kws["bind"] = self.engine_instance
        self.session_maker = self.session_maker_class(**session_kws)
        return self.session_maker

    def _should_commit(self, status_code: int) -> bool:
        """Return whether ``commit_mode`` requires committing for a response with ``status_code``."""
        if self.commit_mode == "autocommit":
            return 200 <= status_code < 300  # noqa: PLR2004
        if self.commit_mode == "autocommit_include_redirect":
            return 200 <= status_code < 400  # noqa: PLR2004
        return False

    async def _release_session(self, session: AsyncSession, request: Request) -> None:
        """Detach ``session`` from the request state, then close it."""
        # Detach first so a failing close() cannot leave a dead session on the request state.
        with contextlib.suppress(AttributeError, KeyError):
            delattr(request.state, self.session_key)
        await session.close()

    async def session_handler(
        self, session: AsyncSession, request: Request, response: Response
    ) -> None:  # pragma: no cover
        """Handles the session after a request is processed.

        Applies the commit strategy and ensures the session is closed and removed from the
        request state, even when the commit or rollback raises.

        Args:
            session (sqlalchemy.ext.asyncio.AsyncSession):
                The database session.
            request (starlette.requests.Request):
                The incoming HTTP request.
            response (starlette.responses.Response):
                The outgoing HTTP response.

        Returns:
            None
        """
        try:
            if self._should_commit(response.status_code):
                await session.commit()
            else:
                await session.rollback()
        finally:
            await self._release_session(session, request)

    async def middleware_dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:  # pragma: no cover
        """Middleware dispatch function to handle requests and responses.

        Processes the request, invokes the next middleware or route handler, and
        applies the session handler after the response is generated. If the downstream
        handler raises, any request session is rolled back and closed before the
        exception is re-raised, so its connection is not leaked.

        Args:
            request (starlette.requests.Request): The incoming HTTP request.
            call_next (starlette.middleware.base.RequestResponseEndpoint):
                The next middleware or route handler.

        Returns:
            starlette.responses.Response: The HTTP response.
        """
        try:
            response = await call_next(request)
        except Exception:
            failed_session: AsyncSession | None = getattr(request.state, self.session_key, None)
            if failed_session is not None:
                try:
                    await failed_session.rollback()
                finally:
                    await self._release_session(failed_session, request)
            raise
        session: AsyncSession | None = getattr(request.state, self.session_key, None)
        if session is not None:
            await self.session_handler(session=session, request=request, response=response)
        return response

    async def close_engine(self) -> None:  # pragma: no cover
        """Close the engine."""
        if self.engine_instance is not None:
            await self.engine_instance.dispose()

    async def on_shutdown(self) -> None:  # pragma: no cover
        """Handles the shutdown event by disposing of the SQLAlchemy engine.

        Ensures that all connections are properly closed during application shutdown and
        removes this config's keys from the application state.

        Returns:
            None
        """
        await self.close_engine()
        if self.app is None:
            return
        for state_key in (self.engine_key, self.session_maker_key, self.session_key):
            # Suppress per key: one missing attribute must not skip cleanup of the others.
            with contextlib.suppress(AttributeError, KeyError):
                delattr(self.app.state, state_key)
@dataclass
class SQLAlchemySyncConfig(_SQLAlchemySyncConfig):
    """SQLAlchemy Sync config for Starlette.

    Blocking session and engine calls (commit, rollback, close, dispose) are executed via
    :func:`starlette.concurrency.run_in_threadpool`.
    """

    app: Starlette | None = None
    """The Starlette application instance."""
    commit_mode: Literal["manual", "autocommit", "autocommit_include_redirect"] = "manual"
    """The commit mode to use for database sessions.

    ``"manual"`` never commits in the middleware (anything left uncommitted is rolled back),
    ``"autocommit"`` commits on 2xx responses and ``"autocommit_include_redirect"`` commits
    on 2xx and 3xx responses.
    """
    engine_key: str = "db_engine"
    """Key to use for the dependency injection of database engines."""
    session_key: str = "db_session"
    """Key to use for the dependency injection of database sessions."""
    session_maker_key: str = "session_maker_class"
    """Key under which to store the SQLAlchemy :class:`sessionmaker <sqlalchemy.orm.sessionmaker>` in the application state instance.
    """
    engine_config: EngineConfig = field(default_factory=EngineConfig)  # pyright: ignore[reportIncompatibleVariableOverride]
    """Configuration for the SQLAlchemy engine.

    The configuration options are documented in the SQLAlchemy documentation.
    """

    def init_app(self, app: Starlette) -> None:
        """Initialize the Starlette application with this configuration.

        Prefixes and de-duplicates the state key names, creates the session maker (and
        engine, if needed), and registers the session middleware plus startup/shutdown handlers.

        Args:
            app: The Starlette application instance.
        """
        self.app = app
        self.bind_key = self.bind_key or "default"
        # NOTE(review): _make_unique_state_key only avoids names already present on app.state.
        # Nothing in this class sets these attributes, so two configs initialised with the same
        # keys may resolve to the same names unless the caller stores them — confirm.
        self.session_key = _make_unique_state_key(app, f"advanced_alchemy_sync_session_{self.session_key}")
        self.engine_key = _make_unique_state_key(app, f"advanced_alchemy_sync_engine_{self.engine_key}")
        self.session_maker_key = _make_unique_state_key(
            app, f"advanced_alchemy_sync_session_maker_{self.session_maker_key}"
        )
        _ = self.create_session_maker()
        app.add_middleware(BaseHTTPMiddleware, dispatch=self.middleware_dispatch)
        app.add_event_handler("startup", self.on_startup)  # pyright: ignore[reportUnknownMemberType]
        app.add_event_handler("shutdown", self.on_shutdown)  # pyright: ignore[reportUnknownMemberType]

    async def on_startup(self) -> None:
        """Create all metadata tables on application startup when ``create_all`` is enabled."""
        if self.create_all:
            # NOTE(review): create_all_metadata is defined outside this file. If the sync base
            # class implements it as a regular (non-async) method, awaiting its None result
            # raises TypeError — confirm it is a coroutine function here.
            await self.create_all_metadata()

    def create_session_maker(self) -> Callable[[], Session]:
        """Get a session maker. If none exists yet, create one.

        The engine is created on demand and used as the session ``bind`` unless the
        session configuration already provides one.

        Returns:
            Callable[[], Session]: Session factory used by the plugin.
        """
        if self.session_maker:
            return self.session_maker
        session_kws = self.session_config_dict
        if self.engine_instance is None:
            self.engine_instance = self.get_engine()
        if session_kws.get("bind") is None:
            session_kws["bind"] = self.engine_instance
        self.session_maker = self.session_maker_class(**session_kws)
        return self.session_maker

    def _should_commit(self, status_code: int) -> bool:
        """Return whether ``commit_mode`` requires committing for a response with ``status_code``."""
        if self.commit_mode == "autocommit":
            return 200 <= status_code < 300  # noqa: PLR2004
        if self.commit_mode == "autocommit_include_redirect":
            return 200 <= status_code < 400  # noqa: PLR2004
        return False

    async def _release_session(self, session: Session, request: Request) -> None:
        """Detach ``session`` from the request state, then close it in the threadpool."""
        # Detach first so a failing close() cannot leave a dead session on the request state.
        with contextlib.suppress(AttributeError, KeyError):
            delattr(request.state, self.session_key)
        await run_in_threadpool(session.close)

    async def session_handler(self, session: Session, request: Request, response: Response) -> None:  # pragma: no cover
        """Handles the session after a request is processed.

        Applies the commit strategy and ensures the session is closed and removed from the
        request state, even when the commit or rollback raises.

        Args:
            session (sqlalchemy.orm.Session):
                The database session.
            request (starlette.requests.Request):
                The incoming HTTP request.
            response (starlette.responses.Response):
                The outgoing HTTP response.

        Returns:
            None
        """
        try:
            if self._should_commit(response.status_code):
                await run_in_threadpool(session.commit)
            else:
                await run_in_threadpool(session.rollback)
        finally:
            await self._release_session(session, request)

    async def middleware_dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:  # pragma: no cover
        """Middleware dispatch function to handle requests and responses.

        Processes the request, invokes the next middleware or route handler, and
        applies the session handler after the response is generated. If the downstream
        handler raises, any request session is rolled back and closed before the
        exception is re-raised, so its connection is not leaked.

        Args:
            request (starlette.requests.Request): The incoming HTTP request.
            call_next (starlette.middleware.base.RequestResponseEndpoint):
                The next middleware or route handler.

        Returns:
            starlette.responses.Response: The HTTP response.
        """
        try:
            response = await call_next(request)
        except Exception:
            failed_session: Session | None = getattr(request.state, self.session_key, None)
            if failed_session is not None:
                try:
                    await run_in_threadpool(failed_session.rollback)
                finally:
                    await self._release_session(failed_session, request)
            raise
        session: Session | None = getattr(request.state, self.session_key, None)
        if session is not None:
            await self.session_handler(session=session, request=request, response=response)
        return response

    async def close_engine(self) -> None:  # pragma: no cover
        """Close the engine."""
        if self.engine_instance is not None:
            await run_in_threadpool(self.engine_instance.dispose)

    async def on_shutdown(self) -> None:  # pragma: no cover
        """Handles the shutdown event by disposing of the SQLAlchemy engine.

        Ensures that all connections are properly closed during application shutdown and
        removes this config's keys from the application state.

        Returns:
            None
        """
        await self.close_engine()
        if self.app is None:
            return
        for state_key in (self.engine_key, self.session_key, self.session_maker_key):
            # Suppress per key: one missing attribute must not skip cleanup of the others.
            with contextlib.suppress(AttributeError, KeyError):
                delattr(self.app.state, state_key)