Source code for advanced_alchemy.extensions.flask.config

"""Configuration classes for Flask integration.

This module provides configuration classes for integrating SQLAlchemy with Flask applications,
including both synchronous and asynchronous database configurations.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, cast

from click import echo
from flask import g, has_request_context
from sqlalchemy.exc import OperationalError
from typing_extensions import Literal

from advanced_alchemy._serialization import decode_json, encode_json
from advanced_alchemy.base import metadata_registry
from advanced_alchemy.config import EngineConfig as _EngineConfig
from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig as _SQLAlchemyAsyncConfig
from advanced_alchemy.config.sync import SQLAlchemySyncConfig as _SQLAlchemySyncConfig
from advanced_alchemy.exceptions import ImproperConfigurationError
from advanced_alchemy.service import schema_dump

if TYPE_CHECKING:
    from typing import Any

    from flask import Flask, Response
    from sqlalchemy.ext.asyncio import AsyncSession
    from sqlalchemy.orm import Session

    from advanced_alchemy.utils.portals import Portal

# Public API of this module: the Flask-flavoured engine and session configs.
__all__ = ("EngineConfig", "SQLAlchemyAsyncConfig", "SQLAlchemySyncConfig")

# Type variable bounded to either Flask config class, for annotating helpers generic over both.
ConfigT = TypeVar("ConfigT", bound="Union[SQLAlchemySyncConfig, SQLAlchemyAsyncConfig]")


def serializer(value: Any) -> str:
    """Encode a JSON column value as a string.

    The value is first reduced to built-in types with :func:`schema_dump`,
    and the result is then encoded with ``encode_json``.

    Args:
        value: The object to serialize; presumably anything :func:`schema_dump`
            can reduce to JSON-compatible built-ins.

    Returns:
        str: The JSON text produced by ``encode_json``.
    """
    builtin_value = schema_dump(value)
    return encode_json(builtin_value)


@dataclass
class EngineConfig(_EngineConfig):
    """Engine settings for SQLAlchemy with Flask-oriented JSON hooks.

    Extends the base ``EngineConfig`` so that JSON columns are decoded with
    ``decode_json`` and encoded with this module's ``serializer`` by default.

    Reference: https://docs.sqlalchemy.org/en/20/core/engines.html

    Attributes:
        json_deserializer: Callable that turns a JSON string into a Python object.
        json_serializer: Callable that turns a Python object into a JSON string.
    """

    json_deserializer: Callable[[str], Any] = decode_json
    """Callable used by dialects supporting :class:`~sqlalchemy.types.JSON` to turn a JSON string into a Python object."""

    json_serializer: Callable[[Any], str] = serializer
    """Callable used by dialects supporting the JSON datatype to render a Python object as JSON text."""
@dataclass
class SQLAlchemySyncConfig(_SQLAlchemySyncConfig):
    """Flask-specific synchronous SQLAlchemy configuration.

    Attributes:
        app: The Flask application instance.
        commit_mode: The commit mode to use for database sessions. ``"manual"``
            registers no commit hook; ``"autocommit"`` commits after 2xx
            responses; ``"autocommit_include_redirect"`` commits after 2xx and
            3xx responses.
    """

    app: Flask | None = None
    """The Flask application instance."""
    commit_mode: Literal["manual", "autocommit", "autocommit_include_redirect"] = "manual"
    """The commit mode to use for database sessions."""

    def create_session_maker(self) -> Callable[[], Session]:
        """Get a session maker. If none exists yet, create one.

        The engine is created lazily via ``get_engine()`` and is injected as the
        session ``bind`` unless ``session_config_dict`` already provides one.

        Returns:
            Callable[[], Session]: Session factory used by the plugin.
        """
        if self.session_maker:
            return self.session_maker

        session_kws = self.session_config_dict
        if self.engine_instance is None:
            self.engine_instance = self.get_engine()
        if session_kws.get("bind") is None:
            session_kws["bind"] = self.engine_instance
        self.session_maker = self.session_maker_class(**session_kws)
        return self.session_maker

    def init_app(self, app: Flask, portal: Portal | None = None) -> None:
        """Initialize the Flask application with this configuration.

        Args:
            app: The Flask application instance.
            portal: The portal to use for thread-safe communication. Unused in
                synchronous configurations.
        """
        self.app = app
        self.bind_key = self.bind_key or "default"
        if self.create_all:
            self.create_all_metadata()
        if self.commit_mode != "manual":
            self._setup_session_handling(app)

    def _setup_session_handling(self, app: Flask) -> None:
        """Register an ``after_request`` hook that commits (per ``commit_mode``) and closes the request session.

        Args:
            app: The Flask application instance.
        """

        @app.after_request
        def handle_db_session(response: Response) -> Response:  # pyright: ignore[reportUnusedFunction]
            """Commit the session if the response meets the commit criteria, then close it."""
            if not has_request_context():
                return response
            db_session = cast("Optional[Session]", g.pop(f"advanced_alchemy_session_{self.bind_key}", None))
            if db_session is None:
                return response

            status = response.status_code
            should_commit = (self.commit_mode == "autocommit" and 200 <= status < 300) or (  # noqa: PLR2004
                self.commit_mode == "autocommit_include_redirect" and 200 <= status < 400  # noqa: PLR2004
            )
            try:
                if should_commit:
                    db_session.commit()
            finally:
                # The session was popped from ``g`` above and is no longer reachable
                # through it, so close it here even when commit() raises.
                db_session.close()
            return response

    def close_engines(self, portal: Portal) -> None:
        """Close the engines.

        Args:
            portal: The portal to use for thread-safe communication. Unused here;
                the synchronous engine is disposed directly.
        """
        if self.engine_instance is not None:
            self.engine_instance.dispose()

    def create_all_metadata(self) -> None:  # pragma: no cover
        """Create all metadata tables in the database.

        Failures surfacing as ``OperationalError`` (including failing to connect
        or to commit) are reported via ``echo`` instead of propagating.
        """
        if self.engine_instance is None:
            self.engine_instance = self.get_engine()
        metadata = metadata_registry.get(None if self.bind_key == "default" else self.bind_key)
        try:
            # The ``try`` wraps the whole transaction so a failed DDL propagates
            # through ``begin()``'s exit (rolling back) instead of being swallowed
            # inside it and then committed.
            with self.engine_instance.begin() as conn:
                metadata.create_all(conn)
        except OperationalError as exc:
            echo(f" * Could not create target metadata. Reason: {exc}")
        else:
            echo(" * Created target metadata.")
@dataclass
class SQLAlchemyAsyncConfig(_SQLAlchemyAsyncConfig):
    """Flask-specific asynchronous SQLAlchemy configuration.

    Attributes:
        app: The Flask application instance.
        commit_mode: The commit mode to use for database sessions. ``"autocommit"``
            commits after 2xx responses; ``"autocommit_include_redirect"`` commits
            after 2xx and 3xx responses; ``"manual"`` never commits automatically.
    """

    app: Flask | None = None
    """The Flask application instance."""
    commit_mode: Literal["manual", "autocommit", "autocommit_include_redirect"] = "manual"
    """The commit mode to use for database sessions."""

    def create_session_maker(self) -> Callable[[], AsyncSession]:
        """Get a session maker. If none exists yet, create one.

        The engine is created lazily via ``get_engine()`` and is injected as the
        session ``bind`` unless ``session_config_dict`` already provides one.

        Returns:
            Callable[[], AsyncSession]: Session factory used by the plugin.
        """
        if self.session_maker:
            return self.session_maker

        session_kws = self.session_config_dict
        if self.engine_instance is None:
            self.engine_instance = self.get_engine()
        if session_kws.get("bind") is None:
            session_kws["bind"] = self.engine_instance
        self.session_maker = self.session_maker_class(**session_kws)
        return self.session_maker

    def init_app(self, app: Flask, portal: Portal | None = None) -> None:
        """Initialize the Flask application with this configuration.

        Args:
            app: The Flask application instance.
            portal: The portal to use for thread-safe communication.

        Raises:
            ImproperConfigurationError: If portal is not provided for async configuration.
        """
        # Validate before touching any state so a failed call leaves the config unchanged.
        if portal is None:
            msg = "Portal is required for asynchronous configurations"
            raise ImproperConfigurationError(msg)
        self.app = app
        self.bind_key = self.bind_key or "default"
        if self.create_all:
            _ = portal.call(self.create_all_metadata)
        self._setup_session_handling(app, portal)

    def _setup_session_handling(self, app: Flask, portal: Portal) -> None:
        """Register request hooks that commit (per ``commit_mode``) and close the request session.

        Args:
            app: The Flask application instance.
            portal: The portal to use for thread-safe communication; a session's own
                ``_session_portal`` attribute takes precedence when present.
        """

        @app.after_request
        def handle_db_session(response: Response) -> Response:  # pyright: ignore[reportUnusedFunction]
            """Commit the session if the response meets the commit criteria, then close it."""
            if not has_request_context():
                return response
            db_session = cast("Optional[AsyncSession]", g.pop(f"advanced_alchemy_session_{self.bind_key}", None))
            if db_session is None:
                return response

            session_portal = getattr(db_session, "_session_portal", None) or portal
            status = response.status_code
            should_commit = (self.commit_mode == "autocommit" and 200 <= status < 300) or (  # noqa: PLR2004
                self.commit_mode == "autocommit_include_redirect" and 200 <= status < 400  # noqa: PLR2004
            )
            try:
                if should_commit:
                    _ = session_portal.call(db_session.commit)
            finally:
                # Popped from ``g`` above, so the teardown hook cannot see it anymore;
                # close it here even when commit raises.
                _ = session_portal.call(db_session.close)
            return response

        @app.teardown_appcontext
        def close_db_session(_: BaseException | None) -> None:  # pyright: ignore[reportUnusedFunction]
            """Close any session still stored in ``g`` when the app context tears down."""
            db_session = cast("Optional[AsyncSession]", g.pop(f"advanced_alchemy_session_{self.bind_key}", None))
            if db_session is not None:
                session_portal = getattr(db_session, "_session_portal", None) or portal
                _ = session_portal.call(db_session.close)

    def close_engines(self, portal: Portal) -> None:
        """Close the engines.

        Args:
            portal: The portal to use for thread-safe communication.
        """
        if self.engine_instance is not None:
            _ = portal.call(self.engine_instance.dispose)

    async def create_all_metadata(self) -> None:  # pragma: no cover
        """Create all metadata tables in the database.

        Failures surfacing as ``OperationalError`` (including failing to connect
        or to commit) are reported via ``echo`` instead of propagating.
        """
        if self.engine_instance is None:
            self.engine_instance = self.get_engine()
        metadata = metadata_registry.get(None if self.bind_key == "default" else self.bind_key)
        try:
            # ``begin()`` commits on clean exit and rolls back if the DDL raises;
            # no explicit ``conn.commit()`` is needed inside the block.
            async with self.engine_instance.begin() as conn:
                await conn.run_sync(metadata.create_all)
        except OperationalError as exc:
            echo(f" * Could not create target metadata. Reason: {exc}")
        else:
            echo(" * Created target metadata.")