Source code for advanced_alchemy.extensions.sanic.extension

from collections.abc import AsyncGenerator, Generator, Sequence
from contextlib import asynccontextmanager, contextmanager
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, overload

from sanic import Request, Sanic

from advanced_alchemy.exceptions import ImproperConfigurationError, MissingDependencyError
from advanced_alchemy.extensions.sanic.config import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig

try:
    from sanic_ext import Extend
    from sanic_ext.extensions.base import Extension

    SANIC_INSTALLED = True
except ModuleNotFoundError:  # pragma: no cover
    SANIC_INSTALLED = False  # pyright: ignore[reportConstantRedefinition]
    Extension = type("Extension", (), {})  # type: ignore
    Extend = type("Extend", (), {})  # type: ignore

if TYPE_CHECKING:
    from sanic import Sanic
    from sqlalchemy import Engine
    from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
    from sqlalchemy.orm import Session


__all__ = ("AdvancedAlchemy",)


[docs] class AdvancedAlchemy(Extension): # type: ignore[no-untyped-call] # pyright: ignore[reportGeneralTypeIssues,reportUntypedBaseClass] """Sanic extension for integrating Advanced Alchemy with SQLAlchemy. Args: config: One or more configurations for SQLAlchemy. app: The Sanic application instance. """ name = "AdvancedAlchemy"
[docs] def __init__( self, *, sqlalchemy_config: Union[ "SQLAlchemyAsyncConfig", "SQLAlchemySyncConfig", Sequence[Union["SQLAlchemyAsyncConfig", "SQLAlchemySyncConfig"]], ], sanic_app: Optional["Sanic[Any, Any]"] = None, ) -> None: if not SANIC_INSTALLED: # pragma: no cover msg = "Could not locate either Sanic or Sanic Extensions. Both libraries must be installed to use Advanced Alchemy. Try: pip install sanic[ext]" raise MissingDependencyError(msg) self._config = sqlalchemy_config if isinstance(sqlalchemy_config, Sequence) else [sqlalchemy_config] self._mapped_configs: dict[str, Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]] = self.map_configs() self._app = sanic_app self._initialized = False if self._app is not None: self.register(self._app)
[docs] def register(self, sanic_app: "Sanic[Any, Any]") -> None: """Initialize the extension with the given Sanic app.""" self._app = sanic_app Extend.register(self) # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue] self._initialized = True
@property def sanic_app(self) -> "Sanic[Any, Any]": """The Sanic app.""" if self._app is None: # pragma: no cover msg = "AdvancedAlchemy has not been initialized with a Sanic app." raise ImproperConfigurationError(msg) return self._app @property def sqlalchemy_config(self) -> Sequence[Union["SQLAlchemyAsyncConfig", "SQLAlchemySyncConfig"]]: """Current Advanced Alchemy configuration.""" return self._config
[docs] def startup(self, bootstrap: "Extend") -> None: # pyright: ignore[reportUnknownParameterType,reportInvalidTypeForm] """Advanced Alchemy Sanic extension startup hook. Args: bootstrap (sanic_ext.Extend): The Sanic extension bootstrap. """ for config in self.sqlalchemy_config: config.init_app(self.sanic_app, bootstrap) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
[docs] def map_configs(self) -> dict[str, Union["SQLAlchemyAsyncConfig", "SQLAlchemySyncConfig"]]: """Maps the configs to the session bind keys.""" mapped_configs: dict[str, Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]] = {} for config in self.sqlalchemy_config: if config.bind_key is None: config.bind_key = "default" mapped_configs[config.bind_key] = config return mapped_configs
[docs] def get_config(self, key: Optional[str] = None) -> Union["SQLAlchemyAsyncConfig", "SQLAlchemySyncConfig"]: """Get the config for the given key.""" if key is None: key = "default" if key == "default" and len(self.sqlalchemy_config) == 1: key = self.sqlalchemy_config[0].bind_key or "default" config = self._mapped_configs.get(key) if config is None: # pragma: no cover msg = f"Config with key {key} not found" raise ImproperConfigurationError(msg) return config
[docs] def get_async_config(self, key: Optional[str] = None) -> "SQLAlchemyAsyncConfig": """Get the async config for the given key.""" config = self.get_config(key) if not isinstance(config, SQLAlchemyAsyncConfig): # pragma: no cover msg = "Expected an async config, but got a sync config" raise ImproperConfigurationError(msg) return config
[docs] def get_sync_config(self, key: Optional[str] = None) -> "SQLAlchemySyncConfig": """Get the sync config for the given key.""" config = self.get_config(key) if not isinstance(config, SQLAlchemySyncConfig): # pragma: no cover msg = "Expected a sync config, but got an async config" raise ImproperConfigurationError(msg) return config
[docs] @asynccontextmanager async def with_async_session( self, key: Optional[str] = None ) -> AsyncGenerator["AsyncSession", None]: # pragma: no cover """Context manager for getting an async session.""" config = self.get_async_config(key) async with config.get_session() as session: yield session
[docs] @contextmanager def with_sync_session(self, key: Optional[str] = None) -> Generator["Session", None]: # pragma: no cover """Context manager for getting a sync session.""" config = self.get_sync_config(key) with config.get_session() as session: yield session
@overload @staticmethod def _get_session_from_request(request: "Request", config: "SQLAlchemyAsyncConfig") -> "AsyncSession": ... @overload @staticmethod def _get_session_from_request(request: "Request", config: "SQLAlchemySyncConfig") -> "Session": ... @staticmethod def _get_session_from_request( request: "Request", config: Union["SQLAlchemyAsyncConfig", "SQLAlchemySyncConfig"], # pragma: no cover ) -> Union["Session", "AsyncSession"]: # pragma: no cover """Get the session for the given key.""" session = getattr(request.ctx, config.session_key, None) if session is None: setattr(request.ctx, config.session_key, config.get_session()) return cast("Union[Session, AsyncSession]", session)
[docs] def get_session( self, request: "Request", key: Optional[str] = None ) -> Union["Session", "AsyncSession"]: # pragma: no cover """Get the session for the given key.""" config = self.get_config(key) return self._get_session_from_request(request, config)
[docs] def get_async_session(self, request: "Request", key: Optional[str] = None) -> "AsyncSession": # pragma: no cover """Get the async session for the given key.""" config = self.get_async_config(key) return self._get_session_from_request(request, config)
[docs] def get_sync_session(self, request: "Request", key: Optional[str] = None) -> "Session": # pragma: no cover """Get the sync session for the given key.""" config = self.get_sync_config(key) return self._get_session_from_request(request, config)
[docs] def provide_session( self, key: Optional[str] = None ) -> Callable[["Request"], Union["Session", "AsyncSession"]]: # pragma: no cover """Get the session for the given key.""" config = self.get_config(key) def _get_session(request: "Request") -> Union["Session", "AsyncSession"]: return self._get_session_from_request(request, config) return _get_session
[docs] def provide_async_session( self, key: Optional[str] = None ) -> Callable[["Request"], "AsyncSession"]: # pragma: no cover """Get the async session for the given key.""" config = self.get_async_config(key) def _get_session(request: Request) -> "AsyncSession": return self._get_session_from_request(request, config) return _get_session
[docs] def provide_sync_session(self, key: Optional[str] = None) -> Callable[[Request], "Session"]: # pragma: no cover """Get the sync session for the given key.""" config = self.get_sync_config(key) def _get_session(request: Request) -> "Session": return self._get_session_from_request(request, config) return _get_session
[docs] def get_engine(self, key: Optional[str] = None) -> Union["Engine", "AsyncEngine"]: # pragma: no cover """Get the engine for the given key.""" config = self.get_config(key) return config.get_engine()
[docs] def get_async_engine(self, key: Optional[str] = None) -> "AsyncEngine": # pragma: no cover """Get the async engine for the given key.""" config = self.get_async_config(key) return config.get_engine()
[docs] def get_sync_engine(self, key: Optional[str] = None) -> "Engine": # pragma: no cover """Get the sync engine for the given key.""" config = self.get_sync_config(key) return config.get_engine()
[docs] def provide_engine( self, key: Optional[str] = None ) -> Callable[[], Union["Engine", "AsyncEngine"]]: # pragma: no cover """Get the engine for the given key.""" config = self.get_config(key) def _get_engine() -> Union["Engine", "AsyncEngine"]: return config.get_engine() return _get_engine
[docs] def provide_async_engine(self, key: Optional[str] = None) -> Callable[[], "AsyncEngine"]: # pragma: no cover """Get the async engine for the given key.""" config = self.get_async_config(key) def _get_engine() -> "AsyncEngine": return config.get_engine() return _get_engine
[docs] def provide_sync_engine(self, key: Optional[str] = None) -> Callable[[], "Engine"]: # pragma: no cover """Get the sync engine for the given key.""" config = self.get_sync_config(key) def _get_engine() -> "Engine": return config.get_engine() return _get_engine
[docs] def add_session_dependency( self, session_type: type[Union["Session", "AsyncSession"]], key: Optional[str] = None ) -> None: """Add a session dependency to the Sanic app.""" self.sanic_app.ext.add_dependency(session_type, self.provide_session(key)) # pyright: ignore[reportUnknownMemberType]
[docs] def add_engine_dependency( self, engine_type: type[Union["Engine", "AsyncEngine"]], key: Optional[str] = None ) -> None: """Add an engine dependency to the Sanic app.""" self.sanic_app.ext.add_dependency(engine_type, self.provide_engine(key)) # pyright: ignore[reportUnknownMemberType]