"""Configuration classes for Starlette integration.
This module provides configuration classes for integrating SQLAlchemy with Starlette applications,
including both synchronous and asynchronous database configurations.
"""
import contextlib
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Optional, cast
from click import echo
from sqlalchemy.exc import OperationalError
from starlette.concurrency import run_in_threadpool
from starlette.middleware.base import BaseHTTPMiddleware
from typing_extensions import Literal
from advanced_alchemy._serialization import decode_json, encode_json
from advanced_alchemy.base import metadata_registry
from advanced_alchemy.config import EngineConfig as _EngineConfig
from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig as _SQLAlchemyAsyncConfig
from advanced_alchemy.config.sync import SQLAlchemySyncConfig as _SQLAlchemySyncConfig
from advanced_alchemy.service import schema_dump
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from starlette.applications import Starlette
from starlette.middleware.base import RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response
def _make_unique_state_key(app: "Starlette", key: str) -> str: # pragma: no cover
"""Generates a unique state key for the Starlette application.
Ensures that the key does not already exist in the application's state.
Args:
app (starlette.applications.Starlette): The Starlette application instance.
key (str): The base key name.
Returns:
str: A unique key name.
"""
i = 0
while True:
if not hasattr(app.state, key):
return key
key = f"{key}_{i}"
i += i
def serializer(value: Any) -> str:
    """Encode a JSON field value as a string.

    The value is first passed through ``schema_dump`` and the result is then
    encoded with ``encode_json``.

    Args:
        value: Any JSON serializable value.

    Returns:
        str: JSON string representation of the value.
    """
    dumped = schema_dump(value)
    return encode_json(dumped)
@dataclass
class EngineConfig(_EngineConfig):
    """Configuration for SQLAlchemy's Engine.

    This class extends the base EngineConfig with Starlette-specific JSON serialization options.

    For details see: https://docs.sqlalchemy.org/en/20/core/engines.html

    Attributes:
        json_deserializer: Callable for converting JSON strings to Python objects.
        json_serializer: Callable for converting Python objects to JSON strings.
    """

    json_deserializer: Callable[[str], Any] = decode_json
    """For dialects that support the :class:`~sqlalchemy.types.JSON` datatype, this is a Python callable that will
    convert a JSON string to a Python object. By default, this is ``decode_json``."""
    json_serializer: Callable[[Any], str] = serializer
    """For dialects that support the JSON datatype, this is a Python callable that will render a given object as JSON.
    By default, the module-level ``serializer`` function is used."""
@dataclass
class SQLAlchemyAsyncConfig(_SQLAlchemyAsyncConfig):
    """SQLAlchemy Async config for Starlette."""

    app: "Optional[Starlette]" = None
    """The Starlette application instance."""
    commit_mode: Literal["manual", "autocommit", "autocommit_include_redirect"] = "manual"
    """The commit mode to use for database sessions."""
    engine_key: str = "db_engine"
    """Key to use for the dependency injection of database engines."""
    session_key: str = "db_session"
    """Key to use for the dependency injection of database sessions."""
    session_maker_key: str = "session_maker_class"
    """Key under which to store the SQLAlchemy :class:`sessionmaker <sqlalchemy.orm.sessionmaker>` in the application state instance.
    """
    engine_config: EngineConfig = field(default_factory=EngineConfig)  # pyright: ignore[reportIncompatibleVariableOverride]
    """Configuration for the SQLAlchemy engine.

    The configuration options are documented in the SQLAlchemy documentation.
    """

    def init_app(self, app: "Starlette") -> None:
        """Initialize the Starlette application with this configuration.

        Stores the application, defaults ``bind_key`` to ``"default"``, builds the
        session maker, rewrites the three state keys to names that are unique on
        ``app.state``, and registers :meth:`middleware_dispatch` as HTTP middleware.

        Args:
            app: The Starlette application instance.
        """
        self.app = app
        self.bind_key = self.bind_key or "default"
        _ = self.create_session_maker()
        self.session_key = _make_unique_state_key(app, f"advanced_alchemy_async_session_{self.session_key}")
        self.engine_key = _make_unique_state_key(app, f"advanced_alchemy_async_engine_{self.engine_key}")
        self.session_maker_key = _make_unique_state_key(
            app, f"advanced_alchemy_async_session_maker_{self.session_maker_key}"
        )
        app.add_middleware(BaseHTTPMiddleware, dispatch=self.middleware_dispatch)

    async def on_startup(self) -> None:
        """Create all metadata on startup when ``create_all`` is enabled."""
        if self.create_all:
            await self.create_all_metadata()

    def create_session_maker(self) -> Callable[[], "AsyncSession"]:
        """Get a session maker. If none exists yet, create one.

        When no engine instance exists yet, one is obtained from ``get_engine()``.
        The engine is used as the session ``bind`` unless the session
        configuration already supplies one.

        Returns:
            Callable[[], AsyncSession]: Session factory used by the plugin.
        """
        if self.session_maker:
            return self.session_maker

        session_kws = self.session_config_dict
        if self.engine_instance is None:
            self.engine_instance = self.get_engine()
        if session_kws.get("bind") is None:
            session_kws["bind"] = self.engine_instance
        self.session_maker = self.session_maker_class(**session_kws)
        return self.session_maker

    async def session_handler(
        self, session: "AsyncSession", request: "Request", response: "Response"
    ) -> None:  # pragma: no cover
        """Handles the session after a request is processed.

        Commits when ``commit_mode`` is ``"autocommit"`` and the status code is 2xx,
        or when it is ``"autocommit_include_redirect"`` and the status code is 2xx
        or 3xx. In every other case, including ``commit_mode == "manual"``, the
        session is rolled back. The session is always closed afterwards and its
        entry is removed from ``request.state``.

        Args:
            session (sqlalchemy.ext.asyncio.AsyncSession):
                The database session.
            request (starlette.requests.Request):
                The incoming HTTP request.
            response (starlette.responses.Response):
                The outgoing HTTP response.

        Returns:
            None
        """
        try:
            if (self.commit_mode == "autocommit" and 200 <= response.status_code < 300) or (  # noqa: PLR2004
                self.commit_mode == "autocommit_include_redirect" and 200 <= response.status_code < 400  # noqa: PLR2004
            ):
                await session.commit()
            else:
                await session.rollback()
        finally:
            await session.close()
            # The session may already have been removed from the request state.
            with contextlib.suppress(AttributeError, KeyError):
                delattr(request.state, self.session_key)

    async def middleware_dispatch(
        self, request: "Request", call_next: "RequestResponseEndpoint"
    ) -> "Response":  # pragma: no cover
        """Middleware dispatch function to handle requests and responses.

        Processes the request, invokes the next middleware or route handler, and
        applies the session handler after the response is generated. The session
        handler only runs when a session is found on ``request.state`` under
        ``session_key``.

        Args:
            request (starlette.requests.Request): The incoming HTTP request.
            call_next (starlette.middleware.base.RequestResponseEndpoint):
                The next middleware or route handler.

        Returns:
            starlette.responses.Response: The HTTP response.
        """
        # NOTE(review): if ``call_next`` raises, the session handler below is not
        # reached, so a session stored on ``request.state`` is not closed here.
        response = await call_next(request)
        session = cast("Optional[AsyncSession]", getattr(request.state, self.session_key, None))
        if session is not None:
            await self.session_handler(session=session, request=request, response=response)
        return response

    async def close_engine(self) -> None:  # pragma: no cover
        """Dispose of the engine instance, if one exists."""
        if self.engine_instance is not None:
            await self.engine_instance.dispose()

    async def on_shutdown(self) -> None:  # pragma: no cover
        """Handles the shutdown event by disposing of the SQLAlchemy engine.

        After the engine is closed, the engine, session maker and session keys are
        removed from the application state when an application is attached.

        Returns:
            None
        """
        await self.close_engine()
        if self.app is None:
            return
        # Suppress per key: a missing attribute must not prevent the remaining
        # keys from being removed.
        for state_key in (self.engine_key, self.session_maker_key, self.session_key):
            with contextlib.suppress(AttributeError, KeyError):
                delattr(self.app.state, state_key)
@dataclass
class SQLAlchemySyncConfig(_SQLAlchemySyncConfig):
    """SQLAlchemy Sync config for Starlette."""

    app: "Optional[Starlette]" = None
    """The Starlette application instance."""
    commit_mode: Literal["manual", "autocommit", "autocommit_include_redirect"] = "manual"
    """The commit mode to use for database sessions."""
    engine_key: str = "db_engine"
    """Key to use for the dependency injection of database engines."""
    session_key: str = "db_session"
    """Key to use for the dependency injection of database sessions."""
    session_maker_key: str = "session_maker_class"
    """Key under which to store the SQLAlchemy :class:`sessionmaker <sqlalchemy.orm.sessionmaker>` in the application state instance.
    """
    engine_config: EngineConfig = field(default_factory=EngineConfig)  # pyright: ignore[reportIncompatibleVariableOverride]
    """Configuration for the SQLAlchemy engine.

    The configuration options are documented in the SQLAlchemy documentation.
    """

    def init_app(self, app: "Starlette") -> None:
        """Initialize the Starlette application with this configuration.

        Stores the application, defaults ``bind_key`` to ``"default"``, rewrites
        the three state keys to names that are unique on ``app.state``, builds the
        session maker, and registers :meth:`middleware_dispatch` as HTTP middleware.

        Args:
            app: The Starlette application instance.
        """
        self.app = app
        self.bind_key = self.bind_key or "default"
        self.session_key = _make_unique_state_key(app, f"advanced_alchemy_sync_session_{self.session_key}")
        self.engine_key = _make_unique_state_key(app, f"advanced_alchemy_sync_engine_{self.engine_key}")
        self.session_maker_key = _make_unique_state_key(
            app, f"advanced_alchemy_sync_session_maker_{self.session_maker_key}"
        )
        _ = self.create_session_maker()
        app.add_middleware(BaseHTTPMiddleware, dispatch=self.middleware_dispatch)

    async def on_startup(self) -> None:
        """Create all metadata on startup when ``create_all`` is enabled."""
        if self.create_all:
            # NOTE(review): ``create_all_metadata`` is not defined in this class;
            # confirm the implementation in use returns an awaitable for the
            # sync config, otherwise this ``await`` fails at startup.
            await self.create_all_metadata()

    def create_session_maker(self) -> Callable[[], "Session"]:
        """Get a session maker. If none exists yet, create one.

        When no engine instance exists yet, one is obtained from ``get_engine()``.
        The engine is used as the session ``bind`` unless the session
        configuration already supplies one.

        Returns:
            Callable[[], Session]: Session factory used by the plugin.
        """
        if self.session_maker:
            return self.session_maker

        session_kws = self.session_config_dict
        if self.engine_instance is None:
            self.engine_instance = self.get_engine()
        if session_kws.get("bind") is None:
            session_kws["bind"] = self.engine_instance
        self.session_maker = self.session_maker_class(**session_kws)
        return self.session_maker

    async def session_handler(
        self, session: "Session", request: "Request", response: "Response"
    ) -> None:  # pragma: no cover
        """Handles the session after a request is processed.

        Commits when ``commit_mode`` is ``"autocommit"`` and the status code is 2xx,
        or when it is ``"autocommit_include_redirect"`` and the status code is 2xx
        or 3xx. In every other case, including ``commit_mode == "manual"``, the
        session is rolled back. The session is always closed afterwards and its
        entry is removed from ``request.state``. The session calls are dispatched
        through ``run_in_threadpool``.

        Args:
            session (sqlalchemy.orm.Session):
                The database session.
            request (starlette.requests.Request):
                The incoming HTTP request.
            response (starlette.responses.Response):
                The outgoing HTTP response.

        Returns:
            None
        """
        try:
            if (self.commit_mode == "autocommit" and 200 <= response.status_code < 300) or (  # noqa: PLR2004
                self.commit_mode == "autocommit_include_redirect" and 200 <= response.status_code < 400  # noqa: PLR2004
            ):
                await run_in_threadpool(session.commit)
            else:
                await run_in_threadpool(session.rollback)
        finally:
            await run_in_threadpool(session.close)
            # The session may already have been removed from the request state.
            with contextlib.suppress(AttributeError, KeyError):
                delattr(request.state, self.session_key)

    async def middleware_dispatch(
        self, request: "Request", call_next: "RequestResponseEndpoint"
    ) -> "Response":  # pragma: no cover
        """Middleware dispatch function to handle requests and responses.

        Processes the request, invokes the next middleware or route handler, and
        applies the session handler after the response is generated. The session
        handler only runs when a session is found on ``request.state`` under
        ``session_key``.

        Args:
            request (starlette.requests.Request): The incoming HTTP request.
            call_next (starlette.middleware.base.RequestResponseEndpoint):
                The next middleware or route handler.

        Returns:
            starlette.responses.Response: The HTTP response.
        """
        # NOTE(review): if ``call_next`` raises, the session handler below is not
        # reached, so a session stored on ``request.state`` is not closed here.
        response = await call_next(request)
        session = cast("Optional[Session]", getattr(request.state, self.session_key, None))
        if session is not None:
            await self.session_handler(session=session, request=request, response=response)
        return response

    async def close_engine(self) -> None:  # pragma: no cover
        """Dispose of the engine instance, if one exists, via ``run_in_threadpool``."""
        if self.engine_instance is not None:
            await run_in_threadpool(self.engine_instance.dispose)

    async def on_shutdown(self) -> None:  # pragma: no cover
        """Handles the shutdown event by disposing of the SQLAlchemy engine.

        After the engine is closed, the engine, session maker and session keys are
        removed from the application state when an application is attached.

        Returns:
            None
        """
        await self.close_engine()
        if self.app is None:
            return
        # Suppress per key: a missing attribute must not prevent the remaining
        # keys from being removed.
        for state_key in (self.engine_key, self.session_maker_key, self.session_key):
            with contextlib.suppress(AttributeError, KeyError):
                delattr(self.app.state, state_key)