Source code for advanced_alchemy.extensions.sanic

from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING, Any, Callable, Generic, Protocol, cast, overload

from sanic import HTTPResponse, Request, Sanic  # noqa: TCH002
from sqlalchemy import Engine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker

from advanced_alchemy.config.common import EngineT, SessionMakerT, SessionT
from advanced_alchemy.exceptions import MissingDependencyError

# Optional-dependency guard: import the Sanic / Sanic Extensions names used by
# the class definitions in this module, and record whether that succeeded in
# ``SANIC_INSTALLED`` (checked in ``SanicAdvancedAlchemy.__init__``).
# NOTE(review): ``sanic`` itself is already imported unconditionally above, so
# this fallback presumably only matters when ``sanic_ext`` is missing — confirm.
try:
    from sanic.helpers import Default, _default
    from sanic_ext import Extend
    from sanic_ext.extensions.base import Extension

    SANIC_INSTALLED = True
except ModuleNotFoundError:
    SANIC_INSTALLED = False
    # Empty placeholder classes so the names referenced below (base class,
    # annotations, the ``counters`` default) still exist at import time.
    Extension = type("Extension", (), {})  # type: ignore  # noqa: PGH003
    Extend = type("Extend", (), {})  # type: ignore  # noqa: PGH003
    Default = type("Default", (), {})  # type: ignore  # noqa: PGH003
    _default = Default()

if TYPE_CHECKING:
    from sqlalchemy.orm import Session, sessionmaker

    from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig
    from advanced_alchemy.config.sync import SQLAlchemySyncConfig
    from advanced_alchemy.config.types import CommitStrategy


class CommitStrategyExecutor(Protocol):
    """Structural type for the coroutine that finalizes a session's transaction.

    Implementations are awaited with keyword-only arguments: the session bound
    to the current request and the response that is about to be sent.
    """

    async def __call__(self, *, session: Session | AsyncSession, response: HTTPResponse) -> None: ...
class SanicAdvancedAlchemy(Extension, Generic[EngineT, SessionT, SessionMakerT]):
    """Sanic extension exposing a SQLAlchemy engine, session maker and per-request session.

    The engine and session maker are built from the supplied configuration when
    the extension is constructed. ``startup`` then registers server listeners,
    request/response middleware and dependency-injection providers on the app.
    """

    name = "AdvancedAlchemy"

    @overload
    def __init__(
        self: SanicAdvancedAlchemy[AsyncEngine, AsyncSession, async_sessionmaker[AsyncSession]],
        *,
        sqlalchemy_config: SQLAlchemyAsyncConfig,
        autocommit: CommitStrategy | None = None,
        counters: Default | bool = _default,  # pyright: ignore[reportInvalidTypeForm]
        session_maker_key: str = "sessionmaker",
        engine_key: str = "engine",
        session_key: str = "session",
    ) -> None: ...

    @overload
    def __init__(
        self: SanicAdvancedAlchemy[Engine, Session, sessionmaker[Session]],
        *,
        sqlalchemy_config: SQLAlchemySyncConfig,
        autocommit: CommitStrategy | None = None,
        counters: Default | bool = _default,  # pyright: ignore[reportInvalidTypeForm]
        session_maker_key: str = "sessionmaker",
        engine_key: str = "engine",
        session_key: str = "session",
    ) -> None: ...

    def __init__(
        self: (
            SanicAdvancedAlchemy[AsyncEngine, AsyncSession, async_sessionmaker[AsyncSession]]
            | SanicAdvancedAlchemy[Engine, Session, sessionmaker[Session]]
        ),
        *,
        sqlalchemy_config: SQLAlchemySyncConfig | SQLAlchemyAsyncConfig,
        autocommit: CommitStrategy | None = None,
        counters: Default | bool = _default,  # pyright: ignore[reportInvalidTypeForm]
        session_maker_key: str = "sessionmaker",
        engine_key: str = "engine",
        session_key: str = "session",
    ) -> None:
        """Build the engine and session maker from ``sqlalchemy_config``.

        Args:
            sqlalchemy_config: Sync or async SQLAlchemy configuration. Its
                ``engine_instance`` is used when set; otherwise ``get_engine()``
                is called.
            autocommit: Key into the commit-strategy table (``"always"`` or
                ``"match_status"``). A falsy value disables automatic commits.
            counters: Stored on the instance as ``self.counters``; not otherwise
                read by this class.
            session_maker_key: Attribute name on ``app.ctx`` for the session maker.
            engine_key: Attribute name on ``app.ctx`` for the engine.
            session_key: Attribute name on ``request.ctx`` for the session.

        Raises:
            MissingDependencyError: If the Sanic / Sanic Extensions imports at
                module load failed.
        """
        if not SANIC_INSTALLED:
            msg = "Could not locate either Sanic or Sanic Extensions. Both libraries must be installed to use Advanced Alchemy. Try: pip install sanic[ext]"
            raise MissingDependencyError(
                msg,
            )
        self.sqlalchemy_config = sqlalchemy_config
        self.engine_key = engine_key
        self.session_maker_key = session_maker_key
        self.session_key = session_key
        self.autocommit_strategy = autocommit
        # Dispatch table: strategy name -> coroutine that finalizes the session.
        self._commit_strategies: dict[CommitStrategy, CommitStrategyExecutor] = {
            "always": self._commit_strategy_always,
            "match_status": self._commit_strategy_match_status,
        }
        self.counters = counters
        self.engine = (
            self.sqlalchemy_config.engine_instance
            if self.sqlalchemy_config.engine_instance is not None
            else self.sqlalchemy_config.get_engine()
        )
        self.session_maker = self.sqlalchemy_config.create_session_maker()
        session_maker = cast("SessionMakerT", self.session_maker)
        self.session_class = session_maker.class_
        self.app: Sanic

    @staticmethod
    async def _run_sync(func: Callable[[], Any]) -> None:
        """Await a blocking zero-argument callable on the loop's default executor."""
        # Always called from a coroutine, so a running loop is guaranteed.
        await asyncio.get_running_loop().run_in_executor(None, func)

    async def _do_commit(self, session: Session | AsyncSession) -> None:
        """Commit ``session``, off-loading to the executor when it is a sync session."""
        if isinstance(session, AsyncSession):
            await session.commit()
        else:
            await self._run_sync(session.commit)

    async def _do_rollback(self, session: Session | AsyncSession) -> None:
        """Roll back ``session``, off-loading to the executor when it is a sync session."""
        if isinstance(session, AsyncSession):
            await session.rollback()
        else:
            await self._run_sync(session.rollback)

    async def _do_close(self, session: Session | AsyncSession) -> None:
        """Close ``session``, off-loading to the executor when it is a sync session."""
        if isinstance(session, AsyncSession):
            await session.close()
        else:
            await self._run_sync(session.close)

    async def _commit_strategy_always(self, *, session: Session | AsyncSession, response: HTTPResponse) -> None:
        """Commit regardless of the response."""
        await self._do_commit(session)

    async def _commit_strategy_match_status(self, *, session: Session | AsyncSession, response: HTTPResponse) -> None:
        """Commit when the response status is 2xx, otherwise roll back."""
        if 200 <= response.status < 300:  # noqa: PLR2004
            await self._do_commit(session)
        else:
            await self._do_rollback(session)

    async def session_handler(self, session: Session | AsyncSession, request: Request, response: HTTPResponse) -> None:
        """Finalize the request's session: apply the commit strategy, close, detach.

        Args:
            session: The session to finalize.
            request: Request whose ``ctx`` holds the session under ``session_key``.
            response: Response passed to the commit strategy.
        """
        try:
            if self.autocommit_strategy:
                await self._commit_strategies[self.autocommit_strategy](session=session, response=response)
        finally:
            try:
                await self._do_close(session)
            finally:
                # Detach even when close() raises, so no stale session stays on
                # the request context; tolerate it having been removed already.
                if hasattr(request.ctx, self.session_key):
                    delattr(request.ctx, self.session_key)

    def get_engine(self) -> EngineT:
        """Return the engine stored on ``app.ctx``, storing ``self.engine`` there if absent."""
        stored = getattr(self.app.ctx, self.engine_key, None)
        if stored is not None:
            return cast("EngineT", stored)
        engine = cast("EngineT", self.engine)
        setattr(self.app.ctx, self.engine_key, engine)
        return engine

    def get_sessionmaker(self) -> Callable[[], SessionT]:
        """Return the session maker created in ``__init__``."""
        return cast("Callable[[], SessionT]", self.session_maker)

    def get_session(self, request: Request) -> SessionT:
        """Return the session on ``request.ctx``, creating and storing one if absent."""
        existing = getattr(request.ctx, self.session_key, None)
        if existing is not None:
            return cast("SessionT", existing)
        session = cast("SessionT", self.session_maker())
        setattr(request.ctx, self.session_key, session)
        return session

    def get_engine_from_request(self, request: Request) -> EngineT:
        """Return the ``engine_key`` attribute of ``request.app.ctx`` (``None`` if unset)."""
        return cast("EngineT", getattr(request.app.ctx, self.engine_key, None))

    def get_sessionmaker_from_request(self, request: Request) -> SessionMakerT:
        """Return the ``session_maker_key`` attribute of ``request.app.ctx`` (``None`` if unset)."""
        return cast("SessionMakerT", getattr(request.app.ctx, self.session_maker_key, None))

    def get_session_from_request(self, request: Request) -> SessionT:
        """Return the ``session_key`` attribute of ``request.ctx`` (``None`` if unset)."""
        return cast("SessionT", getattr(request.ctx, self.session_key, None))

    def startup(self, bootstrap: Extend) -> None:
        """Advanced Alchemy Sanic extension startup hook."""

        @self.app.before_server_start
        async def on_startup(_: Any) -> None:
            # Publish engine and session maker on the app context, then register
            # one dependency provider per concrete type.
            setattr(self.app.ctx, self.engine_key, self.engine)
            setattr(self.app.ctx, self.session_maker_key, self.session_maker)
            bootstrap.add_dependency(
                type(self.engine),
                self.get_engine_from_request,
            )
            bootstrap.add_dependency(
                type(self.session_maker),
                self.get_sessionmaker_from_request,
            )
            bootstrap.add_dependency(
                self.session_class,
                self.get_session_from_request,
            )

        @self.app.after_server_stop
        async def on_shutdown(_: Any) -> None:
            # A sync Engine's dispose() blocks, so run it on the executor.
            if isinstance(self.engine, Engine):
                await asyncio.get_running_loop().run_in_executor(None, self.engine.dispose)
            else:
                await self.engine.dispose()
            if hasattr(self.app.ctx, self.engine_key):
                delattr(self.app.ctx, self.engine_key)
            if hasattr(self.app.ctx, self.session_maker_key):
                delattr(self.app.ctx, self.session_maker_key)

        @self.app.middleware("request")
        async def on_request(request: Request) -> None:
            # get_session() reuses an existing session or creates and stores one.
            self.get_session(request)

        @self.app.middleware("response")  # type: ignore[arg-type]
        async def on_response(request: Request, response: HTTPResponse) -> None:
            session: Session | AsyncSession | None = getattr(request.ctx, self.session_key, None)
            if session is not None:
                await self.session_handler(session=session, request=request, response=response)