import contextlib
from collections.abc import AsyncGenerator, Callable, Generator, Sequence
from contextlib import asynccontextmanager, contextmanager
from typing import (
    TYPE_CHECKING,
    Any,
    Optional,
    Union,
    cast,
    overload,
)
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.requests import Request
from advanced_alchemy._listeners import set_async_context
from advanced_alchemy.exceptions import ImproperConfigurationError
from advanced_alchemy.extensions.starlette.config import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
if TYPE_CHECKING:
    from sqlalchemy import Engine
    from sqlalchemy.ext.asyncio import AsyncEngine
    from sqlalchemy.orm import Session
    from starlette.applications import Starlette
[docs]
class AdvancedAlchemy:
    """AdvancedAlchemy integration for Starlette applications.
    This class manages SQLAlchemy sessions and engine lifecycle within a Starlette application.
    It provides middleware for handling transactions based on commit strategies.
    Args:
        config (advanced_alchemy.config.asyncio.SQLAlchemyAsyncConfig | advanced_alchemy.config.sync.SQLAlchemySyncConfig | Sequence of either):
            The SQLAlchemy configuration, or a sequence of configurations, each with a unique ``bind_key``.
        app (starlette.applications.Starlette | None):
            The Starlette application instance. Defaults to None.
    """
[docs]
    def __init__(
        self,
        config: Union[
            SQLAlchemyAsyncConfig, SQLAlchemySyncConfig, Sequence[Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]
        ],
        app: Optional["Starlette"] = None,
    ) -> None:
        """Store the configuration(s) and optionally bind to an application.

        Args:
            config: A single async or sync config, or a sequence of them. A
                single config is wrapped in a one-element list so the rest of
                the class can always iterate over ``self.config``.
            app: When provided, ``init_app`` is called with it immediately.
        """
        if isinstance(config, Sequence):
            self._config = config
        else:
            self._config = [config]
        # ``map_configs`` iterates ``self.config``, so ``_config`` must already be set.
        self._mapped_configs: dict[str, Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]] = self.map_configs()
        # Start unbound; ``init_app`` is what records the application.
        self._app = cast("Optional[Starlette]", None)
        if app is not None:
            self.init_app(app)
    @property
    def config(self) -> Sequence[Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]:
        """Return the configured SQLAlchemy configs as a sequence.

        This is the value stored by ``__init__``: the sequence that was passed
        in, or a one-element list when a single config was given.
        """
        return self._config
[docs]
    def init_app(self, app: "Starlette") -> None:
        """Bind this extension to a Starlette application.

        Records the application, calls ``init_app`` on every config, publishes
        this instance as ``app.state.advanced_alchemy`` and replaces the
        router's lifespan context with one that runs :meth:`lifespan` around
        the application's original lifespan.

        Args:
            app (starlette.applications.Starlette): The Starlette application instance.

        Raises:
            advanced_alchemy.exceptions.ImproperConfigurationError:
                If two or more configs share the same ``bind_key``.
        """
        self._app = app
        bind_keys = [cfg.bind_key for cfg in self.config]
        # Any duplicate collapses in the set, making it shorter than the list.
        if len(set(bind_keys)) != len(bind_keys):  # pragma: no cover
            msg = "Please ensure that each config has a unique name for the `bind_key` attribute.  The default is `default` and can only be bound to a single engine."
            raise ImproperConfigurationError(msg)
        for cfg in self.config:
            cfg.init_app(app)
        app.state.advanced_alchemy = self
        original_lifespan = app.router.lifespan_context

        @asynccontextmanager
        async def wrapped_lifespan(app: "Starlette") -> AsyncGenerator[Any, None]:  # pragma: no cover
            # Our lifespan is entered first and therefore exited last, so the
            # original lifespan runs entirely between startup and shutdown.
            async with self.lifespan(app):
                async with original_lifespan(app) as state:
                    yield state

        app.router.lifespan_context = wrapped_lifespan
[docs]
    @asynccontextmanager
    async def lifespan(self, app: "Starlette") -> AsyncGenerator[Any, None]:  # pragma: no cover
        """Async context manager that brackets the application's lifetime.

        Awaits :meth:`on_startup` before yielding, then awaits
        :meth:`on_shutdown` when the context exits. If :meth:`on_startup`
        itself raises, :meth:`on_shutdown` is not called.

        Args:
            app: The Starlette application. Not used by this method itself.

        Yields:
            None
        """
        await self.on_startup()
        try:
            yield
        finally:
            # Runs even when the body of the ``async with`` raised.
            await self.on_shutdown()
    @property
    def app(self) -> "Starlette":  # pragma: no cover
        """Return the Starlette application this extension is bound to.

        Raises:
            advanced_alchemy.exceptions.ImproperConfigurationError:
                If no application has been recorded yet (``init_app`` was not called).

        Returns:
            starlette.applications.Starlette: The Starlette application instance.
        """
        bound_app = self._app
        if bound_app is not None:
            return bound_app
        msg = "Application not initialized. Did you forget to call init_app?"
        raise ImproperConfigurationError(msg)  # pragma: no cover
[docs]
    async def on_startup(self) -> None:  # pragma: no cover
        """Await ``on_startup`` on every config, one after another, in order."""
        for cfg in self.config:
            await cfg.on_startup()
[docs]
    async def on_shutdown(self) -> None:  # pragma: no cover
        """Run shutdown for every config and unregister from the app state.

        Awaits ``on_shutdown`` on each config in order, then removes the
        ``advanced_alchemy`` attribute from ``self.app.state``. A missing
        attribute is ignored.
        """
        for cfg in self.config:
            await cfg.on_shutdown()
        # Best-effort cleanup: the attribute may already be gone.
        with contextlib.suppress(AttributeError, KeyError):
            del self.app.state.advanced_alchemy
[docs]
    def map_configs(self) -> dict[str, Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]:
        """Build a lookup table from bind key to config.

        Any config whose ``bind_key`` is ``None`` is mutated in place to use
        ``"default"``. When several configs share a bind key, the last one in
        ``self.config`` wins in the returned mapping.

        Returns:
            A dictionary of config bind keys to configs.
        """
        by_bind_key: dict[str, Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]] = {}
        for cfg in self.config:
            if cfg.bind_key is None:
                # Unnamed configs fall back to the conventional default key.
                cfg.bind_key = "default"
            by_bind_key[cfg.bind_key] = cfg
        return by_bind_key
[docs]
    def get_config(self, key: Optional[str] = None) -> Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]:
        """Look up the config registered under a bind key.

        ``None`` is treated as ``"default"``. When exactly one config is
        registered, ``"default"`` resolves to that config's own bind key, so a
        lone config can be fetched without naming it.

        Args:
            key: The bind key to look up, or ``None`` for the default.

        Raises:
            advanced_alchemy.exceptions.ImproperConfigurationError:
                If no config is registered under the resolved key.

        Returns:
            The config for the given key.
        """
        lookup_key = "default" if key is None else key
        configs = self.config
        if lookup_key == "default" and len(configs) == 1:
            lookup_key = configs[0].bind_key or "default"
        found = self._mapped_configs.get(lookup_key)
        if found is not None:
            return found
        msg = f"Config with key {lookup_key} not found"  # pragma: no cover
        raise ImproperConfigurationError(msg)  # pragma: no cover
[docs]
    def get_async_config(self, key: Optional[str] = None) -> SQLAlchemyAsyncConfig:
        """Look up a config by bind key and require it to be async.

        Args:
            key: The bind key to look up, or ``None`` for the default.

        Raises:
            advanced_alchemy.exceptions.ImproperConfigurationError:
                If the config is not found, or if it is not an async config.

        Returns:
            The async config for the given key.
        """
        candidate = self.get_config(key)
        if isinstance(candidate, SQLAlchemyAsyncConfig):
            return candidate
        msg = "Expected an async config, but got a sync config"  # pragma: no cover
        raise ImproperConfigurationError(msg)  # pragma: no cover
[docs]
    def get_sync_config(self, key: Optional[str] = None) -> SQLAlchemySyncConfig:
        """Look up a config by bind key and require it to be sync.

        Args:
            key: The bind key to look up, or ``None`` for the default.

        Raises:
            advanced_alchemy.exceptions.ImproperConfigurationError:
                If the config is not found, or if it is not a sync config.

        Returns:
            The sync config for the given key.
        """
        candidate = self.get_config(key)
        if isinstance(candidate, SQLAlchemySyncConfig):
            return candidate
        msg = "Expected a sync config, but got an async config"  # pragma: no cover
        raise ImproperConfigurationError(msg)  # pragma: no cover
[docs]
    @asynccontextmanager
    async def with_async_session(
        self, key: Optional[str] = None
    ) -> AsyncGenerator["AsyncSession", None]:  # pragma: no cover
        """Async context manager that yields a session from an async config.

        The session comes from the config's own ``get_session()`` context
        manager, which is entered and exited around the yield.

        Args:
            key: The bind key of the async config, or ``None`` for the default.

        Yields:
            The async session for the given key.
        """
        async with self.get_async_config(key).get_session() as db_session:
            yield db_session
[docs]
    @contextmanager
    def with_sync_session(self, key: Optional[str] = None) -> Generator["Session", None]:  # pragma: no cover
        """Context manager that yields a session from a sync config.

        The session comes from the config's own ``get_session()`` context
        manager, which is entered and exited around the yield.

        Args:
            key: The bind key of the sync config, or ``None`` for the default.

        Yields:
            The sync session for the given key.
        """
        with self.get_sync_config(key).get_session() as db_session:
            yield db_session
    @overload
    @staticmethod
    def _get_session_from_request(request: Request, config: SQLAlchemyAsyncConfig) -> "AsyncSession": ...
    @overload
    @staticmethod
    def _get_session_from_request(request: Request, config: SQLAlchemySyncConfig) -> "Session": ...
    @staticmethod
    def _get_session_from_request(
        request: Request,
        config: Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig],  # pragma: no cover
    ) -> Union["Session", "AsyncSession"]:  # pragma: no cover
        """Return the request-scoped session for a config, creating it on first use.

        The session is cached on ``request.state`` under ``config.session_key``.
        If nothing (or ``None``) is stored there, a new session is built from
        ``config.create_session_maker()`` and stored. Each call also invokes
        ``set_async_context`` with whether the session is an ``AsyncSession``.

        Args:
            request: The request object.
            config: The config object.

        Returns:
            The session stored on the request for this config.
        """
        state = request.state
        attr_name = config.session_key
        session = getattr(state, attr_name, None)
        if session is None:
            session_factory = config.create_session_maker()
            session = session_factory()
            # Cache on the request so later lookups reuse the same session.
            setattr(state, attr_name, session)
        set_async_context(isinstance(session, AsyncSession))
        return session
[docs]
    def get_session(
        self, request: Request, key: Optional[str] = None
    ) -> Union["Session", "AsyncSession"]:  # pragma: no cover
        """Return the request-scoped session for the config with the given key.

        Args:
            request: The request object.
            key: The bind key of the config, or ``None`` for the default.

        Returns:
            The session for the given key; sync or async depending on the config.
        """
        return self._get_session_from_request(request, self.get_config(key))
[docs]
    def get_async_session(self, request: Request, key: Optional[str] = None) -> "AsyncSession":  # pragma: no cover
        """Return the request-scoped session for an async config.

        Args:
            request: The request object.
            key: The bind key of the async config, or ``None`` for the default.

        Returns:
            The async session for the given key.
        """
        return self._get_session_from_request(request, self.get_async_config(key))
[docs]
    def get_sync_session(self, request: Request, key: Optional[str] = None) -> "Session":  # pragma: no cover
        """Return the request-scoped session for a sync config.

        Args:
            request: The request object.
            key: The bind key of the sync config, or ``None`` for the default.

        Returns:
            The sync session for the given key.
        """
        return self._get_session_from_request(request, self.get_sync_config(key))
[docs]
    def provide_session(
        self, key: Optional[str] = None
    ) -> Callable[[Request], Union["Session", "AsyncSession"]]:  # pragma: no cover
        """Build a callable that returns the request-scoped session for a key.

        The config is resolved once, when this method is called; the returned
        callable only needs the request.

        Args:
            key: The bind key of the config, or ``None`` for the default.

        Returns:
            A callable taking a request and returning that request's session.
        """
        bound_config = self.get_config(key)
        # The config never changes after binding, so classify it once up front.
        is_async = isinstance(bound_config, SQLAlchemyAsyncConfig)

        def _get_session(request: Request) -> Union["Session", "AsyncSession"]:
            set_async_context(is_async)
            return self._get_session_from_request(request, bound_config)

        return _get_session
[docs]
    def provide_async_session(
        self, key: Optional[str] = None
    ) -> Callable[[Request], "AsyncSession"]:  # pragma: no cover
        """Build a callable that returns the request-scoped async session for a key.

        The async config is resolved once, when this method is called; the
        returned callable only needs the request.

        Args:
            key: The bind key of the async config, or ``None`` for the default.

        Returns:
            A callable taking a request and returning that request's async session.
        """
        async_config = self.get_async_config(key)

        def _get_session(request: Request) -> "AsyncSession":
            set_async_context(True)
            return self._get_session_from_request(request, async_config)

        return _get_session
[docs]
    def provide_sync_session(self, key: Optional[str] = None) -> Callable[[Request], "Session"]:  # pragma: no cover
        """Build a callable that returns the request-scoped sync session for a key.

        The sync config is resolved once, when this method is called; the
        returned callable only needs the request.

        Args:
            key: The bind key of the sync config, or ``None`` for the default.

        Returns:
            A callable taking a request and returning that request's sync session.
        """
        sync_config = self.get_sync_config(key)

        def _get_session(request: Request) -> "Session":
            set_async_context(False)
            return self._get_session_from_request(request, sync_config)

        return _get_session
[docs]
    def get_engine(self, key: Optional[str] = None) -> Union["Engine", "AsyncEngine"]:  # pragma: no cover
        """Return the engine of the config registered under the given key.

        Args:
            key: The bind key of the config, or ``None`` for the default.

        Returns:
            The result of that config's ``get_engine()``; sync or async
            depending on the config.
        """
        return self.get_config(key).get_engine()
[docs]
    def get_async_engine(self, key: Optional[str] = None) -> "AsyncEngine":  # pragma: no cover
        """Return the engine of the async config registered under the given key.

        Args:
            key: The bind key of the async config, or ``None`` for the default.

        Returns:
            The async engine for the given key.
        """
        return self.get_async_config(key).get_engine()
[docs]
    def get_sync_engine(self, key: Optional[str] = None) -> "Engine":  # pragma: no cover
        """Return the engine of the sync config registered under the given key.

        Args:
            key: The bind key of the sync config, or ``None`` for the default.

        Returns:
            The sync engine for the given key.
        """
        return self.get_sync_config(key).get_engine()
[docs]
    def provide_engine(
        self, key: Optional[str] = None
    ) -> Callable[[], Union["Engine", "AsyncEngine"]]:  # pragma: no cover
        """Build a zero-argument callable that returns the engine for a key.

        The config is resolved once, when this method is called; its
        ``get_engine()`` is invoked each time the returned callable runs.

        Args:
            key: The bind key of the config, or ``None`` for the default.

        Returns:
            A callable returning the engine for the given key.
        """
        bound_config = self.get_config(key)

        def _get_engine() -> Union["Engine", "AsyncEngine"]:
            return bound_config.get_engine()

        return _get_engine
[docs]
    def provide_async_engine(self, key: Optional[str] = None) -> Callable[[], "AsyncEngine"]:  # pragma: no cover
        """Build a zero-argument callable that returns the async engine for a key.

        The async config is resolved once, when this method is called; its
        ``get_engine()`` is invoked each time the returned callable runs.

        Args:
            key: The bind key of the async config, or ``None`` for the default.

        Returns:
            A callable returning the async engine for the given key.
        """
        async_config = self.get_async_config(key)

        def _get_engine() -> "AsyncEngine":
            return async_config.get_engine()

        return _get_engine
[docs]
    def provide_sync_engine(self, key: Optional[str] = None) -> Callable[[], "Engine"]:  # pragma: no cover
        """Build a zero-argument callable that returns the sync engine for a key.

        The sync config is resolved once, when this method is called; its
        ``get_engine()`` is invoked each time the returned callable runs.

        Args:
            key: The bind key of the sync config, or ``None`` for the default.

        Returns:
            A callable returning the sync engine for the given key.
        """
        sync_config = self.get_sync_config(key)

        def _get_engine() -> "Engine":
            return sync_config.get_engine()

        return _get_engine