Source code for advanced_alchemy.mixins.unique

from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING, Any

from sqlalchemy import ColumnElement, select
from sqlalchemy.orm import declarative_mixin

from advanced_alchemy.exceptions import wrap_sqlalchemy_exception

if TYPE_CHECKING:
    from collections.abc import Hashable, Iterator

    from sqlalchemy import Select
    from sqlalchemy.ext.asyncio import AsyncSession
    from sqlalchemy.ext.asyncio.scoping import async_scoped_session
    from sqlalchemy.orm import Session
    from sqlalchemy.orm.scoping import scoped_session
    from typing_extensions import Self

__all__ = ("UniqueMixin",)


[docs] @declarative_mixin class UniqueMixin: """Mixin for instantiating objects while ensuring uniqueness on some field(s). This is a slightly modified implementation derived from https://github.com/sqlalchemy/sqlalchemy/wiki/UniqueObject """ @classmethod @contextmanager def _prevent_autoflush( cls, session: AsyncSession | async_scoped_session[AsyncSession] | Session | scoped_session[Session], ) -> Iterator[None]: with session.no_autoflush, wrap_sqlalchemy_exception(): yield @classmethod def _check_uniqueness( cls, cache: dict[tuple[type[Self], Hashable], Self] | None, session: AsyncSession | async_scoped_session[AsyncSession] | Session | scoped_session[Session], key: tuple[type[Self], Hashable], *args: Any, **kwargs: Any, ) -> tuple[dict[tuple[type[Self], Hashable], Self], Select[tuple[Self]], Self | None]: if cache is None: cache = {} setattr(session, "_unique_cache", cache) statement = select(cls).where(cls.unique_filter(*args, **kwargs)).limit(2) return cache, statement, cache.get(key)
[docs] @classmethod async def as_unique_async( cls, session: AsyncSession | async_scoped_session[AsyncSession], *args: Any, **kwargs: Any, ) -> Self: """Instantiate and return a unique object within the provided session based on the given arguments. If an object with the same unique identifier already exists in the session, it is returned from the cache. Args: session (AsyncSession | async_scoped_session[AsyncSession]): SQLAlchemy async session *args (Any): Values used to instantiate the instance if no duplicate exists **kwargs (Any): Values used to instantiate the instance if no duplicate exists Returns: Self: The unique object instance. """ key = cls, cls.unique_hash(*args, **kwargs) cache, statement, obj = cls._check_uniqueness( getattr(session, "_unique_cache", None), session, key, *args, **kwargs, ) if obj: return obj with cls._prevent_autoflush(session): if (obj := (await session.execute(statement)).scalar_one_or_none()) is None: session.add(obj := cls(*args, **kwargs)) cache[key] = obj return obj
[docs] @classmethod def as_unique_sync( cls, session: Session | scoped_session[Session], *args: Any, **kwargs: Any, ) -> Self: """Instantiate and return a unique object within the provided session based on the given arguments. If an object with the same unique identifier already exists in the session, it is returned from the cache. Args: session (Session | scoped_session[Session]): SQLAlchemy sync session *args (Any): Values used to instantiate the instance if no duplicate exists **kwargs (Any): Values used to instantiate the instance if no duplicate exists Returns: Self: The unique object instance. """ key = cls, cls.unique_hash(*args, **kwargs) cache, statement, obj = cls._check_uniqueness( getattr(session, "_unique_cache", None), session, key, *args, **kwargs, ) if obj: return obj with cls._prevent_autoflush(session): if (obj := session.execute(statement).scalar_one_or_none()) is None: session.add(obj := cls(*args, **kwargs)) cache[key] = obj return obj
[docs] @classmethod def unique_hash(cls, *args: Any, **kwargs: Any) -> Hashable: """Generate a unique key based on the provided arguments. This method should be implemented in the subclass. Args: *args (Any): Values passed to the alternate classmethod constructors **kwargs (Any): Values passed to the alternate classmethod constructors Raises: NotImplementedError: If not implemented in the subclass. Returns: Hashable: Any hashable object. """ msg = "Implement this in subclass" raise NotImplementedError(msg)
[docs] @classmethod def unique_filter(cls, *args: Any, **kwargs: Any) -> ColumnElement[bool]: """Generate a filter condition for ensuring uniqueness. This method should be implemented in the subclass. Args: *args (Any): Values passed to the alternate classmethod constructors **kwargs (Any): Values passed to the alternate classmethod constructors Raises: NotImplementedError: If not implemented in the subclass. Returns: ColumnElement[bool]: Filter condition to establish the uniqueness. """ msg = "Implement this in subclass" raise NotImplementedError(msg)