Source code for advanced_alchemy.repository._util

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Iterable, Literal, Protocol, Sequence, Union, cast, overload

from sqlalchemy import (
    Select,
)
from sqlalchemy.orm import (
    InstrumentedAttribute,
    MapperProperty,
    RelationshipProperty,
    joinedload,
    lazyload,
    selectinload,
)
from sqlalchemy.orm.strategy_options import (
    _AbstractLoad,  # pyright: ignore[reportPrivateUsage]  # pyright: ignore[reportPrivateUsage]
)
from sqlalchemy.sql import ColumnElement, ColumnExpressionArgument
from sqlalchemy.sql.base import ExecutableOption
from typing_extensions import TypeAlias

from advanced_alchemy.exceptions import ErrorMessages
from advanced_alchemy.exceptions import wrap_sqlalchemy_exception as _wrap_sqlalchemy_exception
from advanced_alchemy.filters import (
    InAnyFilter,
    PaginationFilter,
    StatementFilter,
    StatementTypeT,
)
from advanced_alchemy.repository.typing import ModelT, OrderingPair

if TYPE_CHECKING:
    from sqlalchemy import (
        Delete,
        Dialect,
        Update,
    )
    from sqlalchemy.sql.dml import ReturningDelete, ReturningUpdate

    from advanced_alchemy.base import ModelProtocol


# Boolean SQL column expression, as accepted by ``.where()`` clauses.
WhereClauseT: TypeAlias = ColumnExpressionArgument[bool]
# One relationship-loading specification: a ready-made SQLAlchemy loader option,
# the ``"*"`` wildcard, a model attribute, or a relationship/mapper property.
SingleLoad: TypeAlias = Union[
    _AbstractLoad,
    Literal["*"],
    InstrumentedAttribute[Any],
    RelationshipProperty[Any],
    MapperProperty[Any],
]
# A sequence of load specs. An inner sequence is folded by
# ``get_abstract_loader_options`` into one chained loader option (each element
# becomes a sub-option of the previous one).
LoadCollection: TypeAlias = Sequence[Union[SingleLoad, Sequence[SingleLoad]]]
# A sequence of arbitrary SQLAlchemy executable options.
ExecutableOptions: TypeAlias = Sequence[ExecutableOption]
# Any accepted loader specification: a single spec, a collection of specs,
# or raw executable option(s).
LoadSpec: TypeAlias = Union[LoadCollection, SingleLoad, ExecutableOption, ExecutableOptions]

# A field to order by: an attribute name, an instrumented attribute, or a
# relationship property.
OrderByT: TypeAlias = Union[
    str,
    InstrumentedAttribute[Any],
    RelationshipProperty[Any],
]

# NOTE: For backward compatibility with Litestar - this is imported from here within the litestar codebase.
wrap_sqlalchemy_exception = _wrap_sqlalchemy_exception

DEFAULT_ERROR_MESSAGE_TEMPLATES: ErrorMessages = {
    "integrity": "There was a data validation error during processing",
    "foreign_key": "A foreign key is missing or invalid",
    "multiple_rows": "Multiple matching rows found",
    "duplicate_key": "A record matching the supplied data already exists.",
    "other": "There was an error during data processing",
    "check_constraint": "The data failed a check constraint during processing",
}
"""Default error messages for repository errors."""


def get_instrumented_attr(
    model: type[ModelProtocol],
    key: str | InstrumentedAttribute[Any],
) -> InstrumentedAttribute[Any]:
    """Resolve ``key`` to an instrumented attribute of ``model``.

    A string is looked up as an attribute of the model class. Any other value is
    assumed to already be an :class:`sqlalchemy.orm.InstrumentedAttribute` and
    is handed back untouched.

    Args:
        model: SQLAlchemy model class to look the attribute up on.
        key: Attribute name, or an already-resolved
            :class:`sqlalchemy.orm.InstrumentedAttribute`.

    Returns:
        :class:`sqlalchemy.orm.InstrumentedAttribute`: The resolved attribute.

    Raises:
        AttributeError: If ``key`` is a string that names no attribute of ``model``.
    """
    if not isinstance(key, str):
        # Already an attribute object: nothing to resolve.
        return key
    return cast("InstrumentedAttribute[Any]", getattr(model, key))
def model_from_dict(model: type[ModelT], **kwargs: Any) -> ModelT:
    """Instantiate ``model`` from keyword arguments, keeping only mapped columns.

    Keyword arguments whose names are not column keys of the model's mapper are
    dropped. Columns absent from ``kwargs`` are not passed to the constructor.
    The constructor receives the surviving values in mapper column order.

    Args:
        model: The SQLAlchemy model class to instantiate.
        **kwargs: Candidate attribute values keyed by attribute name.

    Returns:
        ModelT: A new instance of ``model`` built from the filtered values.
    """
    # ``.keys()`` is required: iterating a column collection yields Column objects.
    column_names = model.__mapper__.columns.keys()  # pyright: ignore[reportUnknownMemberType]
    init_kwargs: dict[str, Any] = {}
    for column_name in column_names:
        if column_name not in kwargs:
            continue
        init_kwargs[column_name] = kwargs[column_name]
    return model(**init_kwargs)
def get_abstract_loader_options(
    loader_options: LoadSpec | None,
    default_loader_options: list[_AbstractLoad] | None = None,
    default_options_have_wildcards: bool = False,
    merge_with_default: bool = True,
    inherit_lazy_relationships: bool = True,
    cycle_count: int = 0,
) -> tuple[list[_AbstractLoad], bool]:
    """Generate SQLAlchemy loader options for eager loading relationships.

    Args:
        loader_options: Specification for how to load relationships. Can be:

            - ``None``: use only the top-level defaults described below.
            - :class:`sqlalchemy.orm.strategy_options._AbstractLoad`: a ready-made
              loader option, returned as-is.
            - :class:`sqlalchemy.orm.InstrumentedAttribute`: a model relationship
              attribute. Its ``.property`` is resolved like a one-element list.
            - :class:`sqlalchemy.orm.RelationshipProperty`: ``selectinload`` for
              collections (``uselist``), otherwise ``joinedload`` honouring the
              relationship's ``innerjoin`` setting.
            - ``"*"``: wildcard ``joinedload("*")``.
            - a :class:`list` or :class:`tuple` of the above. An element that is
              itself a list/tuple is resolved recursively and folded into ONE
              chained option, each element becoming a sub-option (``.options()``)
              of the one before it. Nested sequences that resolve to nothing
              (e.g. ``[]``) are skipped.

            Any other value (e.g. a string other than ``"*"``) yields no loader option.
        default_loader_options: Loader options to start from (top level only).
        default_options_have_wildcards: Initial value of the returned wildcard flag.
        merge_with_default: Whether ``default_loader_options`` are included (top level only).
        inherit_lazy_relationships: When ``False``, ``lazyload("*")`` is prepended at
            the top level instead of relying on each relationship's ``lazy`` setting.
        cycle_count: Recursion depth. External callers leave it at ``0``.

    Returns:
        tuple[list[_AbstractLoad], bool]: The loader options, and a flag that is
        ``True`` when a ``"*"`` wildcard or a ``joinedload`` of a non-collection
        relationship was produced, or when ``default_options_have_wildcards`` was
        already ``True``.

    Note:
        At the top level, a single ``_AbstractLoad``, ``RelationshipProperty`` or
        ``"*"`` is returned on its own: the ``lazyload("*")`` and
        ``default_loader_options`` are NOT included. By contrast, ``None``, an
        ``InstrumentedAttribute`` and list/tuple inputs are merged with them.
    """
    loads: list[_AbstractLoad] = []
    # Top-level-only prefix: optional lazy wildcard, then the defaults.
    if cycle_count == 0 and not inherit_lazy_relationships:
        loads.append(lazyload("*"))
    if cycle_count == 0 and merge_with_default and default_loader_options is not None:
        loads.extend(default_loader_options)

    options_have_wildcards = default_options_have_wildcards
    if loader_options is None:
        return (loads, options_have_wildcards)
    if isinstance(loader_options, _AbstractLoad):
        # NOTE(review): unlike the sequence branch, this drops ``loads`` at the top level.
        return ([loader_options], options_have_wildcards)
    if isinstance(loader_options, InstrumentedAttribute):
        # Route through the sequence branch below, so the property is resolved
        # recursively and merged with ``loads``.
        loader_options = [loader_options.property]
    if isinstance(loader_options, RelationshipProperty):
        class_ = loader_options.class_attribute
        if loader_options.uselist:
            return ([selectinload(class_)], options_have_wildcards)
        return ([joinedload(class_, innerjoin=loader_options.innerjoin)], True)
    if isinstance(loader_options, str) and loader_options == "*":
        return ([joinedload("*")], True)
    if isinstance(loader_options, (list, tuple)):
        for attribute in loader_options:  # pyright: ignore[reportUnknownVariableType]
            load_chain, options_have_wildcards = get_abstract_loader_options(
                loader_options=attribute,  # pyright: ignore[reportUnknownArgumentType]
                default_options_have_wildcards=options_have_wildcards,
                inherit_lazy_relationships=inherit_lazy_relationships,
                merge_with_default=merge_with_default,
                cycle_count=cycle_count + 1,
            )
            if not isinstance(attribute, (list, tuple)):
                loads.extend(load_chain)
                continue
            if not load_chain:
                # Nothing to chain. Indexing ``load_chain[-1]`` would raise IndexError.
                continue
            # Fold the nested chain right-to-left: A.options(B.options(C)).
            loader = load_chain[-1]
            for sub_load in load_chain[-2::-1]:
                loader = sub_load.options(loader)
            loads.append(loader)
    return (loads, options_have_wildcards)
class FilterableRepositoryProtocol(Protocol[ModelT]):
    """Structural interface that every filterable repository must satisfy.

    An implementation only needs to expose the model class it operates on.

    Type Parameters:
        ModelT: :class:`~advanced_alchemy.base.ModelProtocol`
            The SQLAlchemy model type handled by the repository.

    Attributes:
        model_type: :class:`~advanced_alchemy.base.ModelProtocol`
            The SQLAlchemy model class the repository manages.
    """

    model_type: type[ModelT]
class FilterableRepository(FilterableRepositoryProtocol[ModelT]):
    """Default implementation of a filterable repository.

    Provides core filtering, ordering and pagination functionality for SQLAlchemy models.

    Type Parameters:
        ModelT: :class:`~advanced_alchemy.base.ModelProtocol`
            The SQLAlchemy model type this repository handles.
    """

    model_type: type[ModelT]
    """The SQLAlchemy model class this repository manages."""
    # ``tuple[str, ...]`` (variable length): ``tuple[str]`` would mean exactly one
    # dialect and reject subclasses that list several.
    prefer_any_dialects: tuple[str, ...] | None = ("postgresql",)
    """Dialect names that prefer ``field.id = ANY(:1)`` over ``field.id IN (...)``."""
    order_by: list[OrderingPair] | OrderingPair | None = None
    """List or single :class:`~advanced_alchemy.repository.typing.OrderingPair` to use for sorting."""
    _prefer_any: bool = False
    """Whether to prefer ANY() over IN() in queries."""
    _dialect: Dialect
    """The SQLAlchemy :class:`sqlalchemy.dialects.Dialect` being used."""

    @overload
    def _apply_filters(
        self,
        *filters: StatementFilter | ColumnElement[bool],
        apply_pagination: bool = True,
        statement: Select[tuple[ModelT]],
    ) -> Select[tuple[ModelT]]: ...

    @overload
    def _apply_filters(
        self,
        *filters: StatementFilter | ColumnElement[bool],
        apply_pagination: bool = True,
        statement: Delete,
    ) -> Delete: ...

    @overload
    def _apply_filters(
        self,
        *filters: StatementFilter | ColumnElement[bool],
        apply_pagination: bool = True,
        statement: ReturningDelete[tuple[ModelT]] | ReturningUpdate[tuple[ModelT]],
    ) -> ReturningDelete[tuple[ModelT]] | ReturningUpdate[tuple[ModelT]]: ...

    @overload
    def _apply_filters(
        self,
        *filters: StatementFilter | ColumnElement[bool],
        apply_pagination: bool = True,
        statement: Update,
    ) -> Update: ...

    def _apply_filters(
        self,
        *filters: StatementFilter | ColumnElement[bool],
        apply_pagination: bool = True,
        statement: StatementTypeT,
    ) -> StatementTypeT:
        """Apply filters to a SQL statement, in the order given.

        - :class:`~advanced_alchemy.filters.PaginationFilter` instances are applied
          only when ``apply_pagination`` is true, and skipped otherwise.
        - :class:`sqlalchemy.sql.ColumnElement` expressions are added with ``statement.where()``.
        - Every other filter, including :class:`~advanced_alchemy.filters.InAnyFilter`,
          is applied via ``append_to_statement(statement, self.model_type)``.

        Args:
            *filters: Filter conditions to apply.
            apply_pagination: Whether to apply pagination filters.
            statement: The base SQL statement to filter.

        Returns:
            StatementTypeT: The filtered SQL statement.
        """
        for filter_ in filters:
            if isinstance(filter_, PaginationFilter):
                if apply_pagination:
                    statement = filter_.append_to_statement(statement, self.model_type)
            elif isinstance(filter_, InAnyFilter):
                # NOTE(review): currently identical to the generic ``else`` branch.
                statement = filter_.append_to_statement(statement, self.model_type)
            elif isinstance(filter_, ColumnElement):
                statement = cast("StatementTypeT", statement.where(filter_))
            else:
                statement = filter_.append_to_statement(statement, self.model_type)
        return statement

    def _filter_select_by_kwargs(
        self,
        statement: StatementTypeT,
        kwargs: dict[Any, Any] | Iterable[tuple[Any, Any]],
    ) -> StatementTypeT:
        """Filter a statement with one equality ``WHERE`` clause per criterion.

        Args:
            statement: The SQL statement to filter.
            kwargs: Dictionary or iterable of ``(key, value)`` pairs. Keys are model
                attribute names or instrumented attributes. Each adds ``field == value``.
                With an iterable of pairs, a repeated key keeps its last value.

        Returns:
            StatementTypeT: The filtered SQL statement.

        Raises:
            AttributeError: If a string key names no attribute of the model.
        """
        for key, val in dict(kwargs).items():
            field = get_instrumented_attr(self.model_type, key)
            statement = cast("StatementTypeT", statement.where(field == val))
        return statement

    def _apply_order_by(
        self,
        statement: StatementTypeT,
        order_by: list[tuple[str | InstrumentedAttribute[Any], bool]] | tuple[str | InstrumentedAttribute[Any], bool],
    ) -> StatementTypeT:
        """Apply ordering to a SQL statement.

        Args:
            statement: The SQL statement to order.
            order_by: Ordering specification. Either a single tuple or a list of tuples where:

                - the first element is the field name or
                  :class:`sqlalchemy.orm.InstrumentedAttribute` to order by;
                - the second element is ``True`` for descending, ``False`` for ascending.

        Returns:
            StatementTypeT: The ordered SQL statement (unchanged for non-SELECT statements).
        """
        # Any non-list value is treated as one (field, is_desc) pair.
        if not isinstance(order_by, list):
            order_by = [order_by]
        for order_field, is_desc in order_by:
            field = get_instrumented_attr(self.model_type, order_field)
            statement = self._order_by_attribute(statement, field, is_desc)
        return statement

    def _order_by_attribute(
        self,
        statement: StatementTypeT,
        field: InstrumentedAttribute[Any],
        is_desc: bool,
    ) -> StatementTypeT:
        """Apply ordering by a single attribute to a SQL statement.

        Args:
            statement: The SQL statement to order.
            field: The model attribute to order by.
            is_desc: Whether to order in descending (True) or ascending (False) order.

        Returns:
            StatementTypeT: The ordered statement. A statement that is not a
            :class:`sqlalchemy.Select` is returned unchanged.
        """
        if not isinstance(statement, Select):
            return statement
        return cast("StatementTypeT", statement.order_by(field.desc() if is_desc else field.asc()))